LoginSignup
12
14

More than 5 years have passed since last update.

実録!コードリーディング入門 ~ NumPy rollaxis & transpose 編 ~ (1/2)

Last updated at Posted at 2017-10-08

この記事は、 numpy.rollaxis の実装を追った際の調査過程をメモしたものです。
コードリーディングの過程がわかる情報というのはありそうでなかなかないので、共有します。
コードリーディング初心者の方のお役に立てば幸いです。

きっかけ

きっかけは Chainer の softmax_cross_entropy の実装を読んでいたところから始まります。
numpy.rollaxis という関数を読んでいることがわかり、これは何をやっているんだろうと疑問に思いました。

    def forward_cpu(self, inputs):
        x, t = inputs
        if chainer.is_debug():
            _check_input_values(x, t, self.ignore_label)

        log_y = log_softmax._log_softmax(x)
        if self.cache_score:
            self.y = numpy.exp(log_y)
        if self.class_weight is not None:
            shape = [1 if d != 1 else -1 for d in six.moves.range(x.ndim)]
            log_y *= _broadcast_to(self.class_weight.reshape(shape), x.shape)
        log_yd = numpy.rollaxis(log_y, 1)                                       これ
        log_yd = log_yd.reshape(len(log_yd), -1)
        log_p = log_yd[numpy.maximum(t.ravel(), 0), numpy.arange(t.size)]
...

ドキュメントを読んで見る

とりあえず NumPyのドキュメント を読んでみると、次の一行がサラッと書かれているだけです。

Roll the specified axis backwards, until it lies in a given position.

軸をロールすると言われても、具体的にデータがどう見えるのかがピンとこなかったため、コードを読んでみることにしてみました。

numpy.rollaxis のコードを読む

実装箇所の特定

まず、git から NumPy のソースコードを拾います。

$ git clone https://github.com/numpy/numpy.git

私はよく find & grep コマンドを使うので、 rollaxis を定義している箇所を探してみます。
これらのコマンドに渡すパラメータで、いくつかポイントがあります。
* -type f で、検索対象をファイルに絞る。そうしないと、ディレクトリに対する grep が走ってしまい、エラー出力がやかましくなります。
* grep には -nH オプションを付けて、ファイル名と行数を表示させる。
*
上記のポイントを踏まえてコマンドを実行したところ、下記のような出力が得られました。

$ find -type f -exec grep -nH 'def rollaxis' {} \;
./numpy/core/numeric.py:1415:def rollaxis(a, axis, start=0):

numpy/core/numeric.py の1415行目で定義されているということなので、見てみましょう。

numpy/core/numeric.py
def rollaxis(a, axis, start=0):
    """
    Roll the specified axis backwards, until it lies in a given position.

    This function continues to be supported for backward compatibility, but you
    should prefer `moveaxis`. The `moveaxis` function was added in NumPy
    1.11.

    Parameters
    ----------
    a : ndarray
        Input array.
    axis : int
        The axis to roll backwards.  The positions of the other axes do not
        change relative to one another.
    start : int, optional
        The axis is rolled until it lies before this position.  The default,
        0, results in a "complete" roll.

    Returns
    -------
    res : ndarray
        For NumPy >= 1.10.0 a view of `a` is always returned. For earlier
        NumPy versions a view of `a` is returned only if the order of the
        axes is changed, otherwise the input array is returned.

    See Also
    --------
    moveaxis : Move array axes to new positions.
    roll : Roll the elements of an array by a number of positions along a
        given axis.

    Examples
    --------
    >>> a = np.ones((3,4,5,6))
    >>> np.rollaxis(a, 3, 1).shape
    (3, 6, 4, 5)
    >>> np.rollaxis(a, 2).shape
    (5, 3, 4, 6)
    >>> np.rollaxis(a, 1, 4).shape
    (3, 5, 6, 4)

    """
    n = a.ndim
    axis = normalize_axis_index(axis, n)
    if start < 0:
        start += n
    msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
    if not (0 <= start < n + 1):
        raise AxisError(msg % ('start', -n, 'start', n + 1, start))
    if axis < start:
        # it's been removed
        start -= 1
    if axis == start:
        return a[...]
    axes = list(range(0, n))
    axes.remove(axis)
    axes.insert(start, axis)
    return a.transpose(axes)

NumPy のコードはドキュメントも含んでいるため、コメントがやたら長いです。
コードの中身としては、
1. axes という (0, 1, ..., n-1) のリストを作る (n: もとの配列の次元)
2. axis を取り除く
3. start 番地に axis を追加
4. ndarray.transpose を呼び出す
となります。

やっていることは axis 番目の軸と start 番目の軸を交換するという処理です。
start のデフォルトが0で、axis 番目の軸を先頭に持ってくるという処理が、ロールしているように感じられるのでしょう。

numpy.ndarray.transpose のコードを読む

numpy.rollaxis の処理は、実際は transpose であるということがわかりました。
2次元行列に対する転置であれば処理内容は容易に想像がつきますが、3次元以上の配列に対する転置は、一体どういう処理をしているのでしょうか?
コードを読んで見るために、まずは find & grep でコードを探してみます。

$ find -type f -exec grep -nH 'def transpose' {} \;
./numpy/core/fromnumeric.py:499:def transpose(a, axes=None):
./numpy/linalg/linalg.py:227:def transpose(a):
./numpy/ma/core.py:6863:def transpose(a, axes=None):

3つ出てきましたが、ここは rollaxis と同じ core ディレクトリを掘るのが良さそうです。
というわけで、 frumnumeric.py の 499 行目を見てみます。

numpy/core/fromnumeric.py
def transpose(a, axes=None):
    """
    Permute the dimensions of an array.

    Parameters
    ----------
    a : array_like
        Input array.
    axes : list of ints, optional
        By default, reverse the dimensions, otherwise permute the axes
        according to the values given.

    Returns
    -------
    p : ndarray
        `a` with its axes permuted.  A view is returned whenever
        possible.

    See Also
    --------
    moveaxis
    argsort

    Notes
    -----
    Use `transpose(a, argsort(axes))` to invert the transposition of tensors
    when using the `axes` keyword argument.

    Transposing a 1-D array returns an unchanged view of the original array.

    Examples
    --------
    >>> x = np.arange(4).reshape((2,2))
    >>> x
    array([[0, 1],
           [2, 3]])

    >>> np.transpose(x)
    array([[0, 2],
           [1, 3]])

    >>> x = np.ones((1, 2, 3))
    >>> np.transpose(x, (1, 0, 2)).shape
    (2, 1, 3)

    """
    return _wrapfunc(a, 'transpose', axes)

これまたコメントがやたら長いですが、やっているのは _wrapfunc を呼ぶだけです。
どんどんほっていきましょう。

numpy/core/fromnumeric.py
def _wrapfunc(obj, method, *args, **kwds):
    try:
        return getattr(obj, method)(*args, **kwds)

    # An AttributeError occurs if the object does not have
    # such a method in its class.

    # A TypeError occurs if the object does have such a method
    # in its class, but its signature is not identical to that
    # of NumPy's. This situation has occurred in the case of
    # a downstream library like 'pandas'.
    except (AttributeError, TypeError):
        return _wrapit(obj, method, *args, **kwds)

こうやって見るとオブジェクトのメソッドを呼んでいるだけです。
しかし、そもそも numpy.ndarray のコードをよんでいるはずなので、何か辻褄が合わない…
それもそのはず、これは numpy.transpose のコードだからです。
それでは、numpy.ndarray.transpose のコードはどこにあるのでしょうか?

numpy.ndarray のありかを探す

まずは、 ndarray のありかを探す必要があるようです。
$ find -name '*.py' -exec grep -nH 'ndarray' {} \; して、ndarray を定義している箇所を探します。
流石にこのキーワードだと沢山 Hit しますので、気合でそれっぽい箇所を探します。
大抵がコメントなので、コメントっぽい箇所を流し読みして抜粋したのが以下のものです。

./numpy/core/memmap.py:4:from .numeric import uint8, ndarray, dtype
/numpy/core/numeric.py:18:    min_scalar_type, ndarray, nditer, nested_iters, promote_types,
./numpy/core/numeric.py:48:    'newaxis', 'ndarray', 'flatiter', 'nditer', 'nested_iters', 'ufunc',

どうやら、 numeric.py になにかいそうなので読んでみます。

numpy/core/numeric.py
from __future__ import division, absolute_import, print_function

import collections
import itertools
import operator
import sys
import warnings

import numpy as np
from . import multiarray
from .multiarray import (
    _fastCopyAndTranspose as fastCopyAndTranspose, ALLOW_THREADS,
    BUFSIZE, CLIP, MAXDIMS, MAY_SHARE_BOUNDS, MAY_SHARE_EXACT, RAISE,
    WRAP, arange, array, broadcast, can_cast, compare_chararrays,
    concatenate, copyto, count_nonzero, dot, dtype, empty,
    empty_like, flatiter, frombuffer, fromfile, fromiter, fromstring,
    inner, int_asbuffer, lexsort, matmul, may_share_memory,
    min_scalar_type, ndarray, nditer, nested_iters, promote_types,
    putmask, result_type, set_numeric_ops, shares_memory, vdot, where,
    zeros, normalize_axis_index)

すると、今度は multiarray.py を読もうとしたくなるのが人情ですが、残念ながらそのようなファイルはありません。
この場合、Cで書かれた拡張モジュールを使っている可能性が大です。
実際、NumPy は計算部分がCで書かれているから速いとうたわれているライブラリですので、次はCの世界を掘り下げていくことになります。

(補足)
pip でインストールした NumPy のディレクトリを見てみると、たしかに multiarray.so が存在しており、C側のモジュールを読んでいることがわかります。

~/.pyenv/versions/anaconda3-4.1.1/pkgs/numpy-1.11.1-py35_0/lib/python3.5/site-packages/numpy/core$ ls
arrayprint.py    fromnumeric.py         include       lib          multiarray.so        operand_flag_tests.so  setup.py              tests
cversions.py     function_base.py       info.py       machar.py    multiarray_tests.so  __pycache__            shape_base.py         umath.so
defchararray.py  generate_numpy_api.py  __init__.py   memmap.py    numeric.py           records.py             struct_ufunc_test.so  umath_tests.so
_dummy.so        getlimits.py           _internal.py  _methods.py  numerictypes.py      setup_common.py        test_rational.so

Cを掘り下げる

Cファイルの何処かで、 transpose という文字列を定義している箇所があるはずです。
そこで、 transpose を定義している箇所を探します。

$ find . -type f -exec grep -nH '"transpose"' {} \;
./numpy/core/src/multiarray/methods.c:2707:    {"transpose",
./numpy/core/src/multiarray/scalartypes.c.src:1475:        "transpose",
./numpy/core/src/multiarray/scalartypes.c.src:1955:    {"transpose",

numpy/core/src/multiarray/methods.c の 2707 行目付近を見てみましょう。

numpy/core/src/multiarray/methods.c
NPY_NO_EXPORT PyMethodDef array_methods[] = {

...

    {"transpose",
        (PyCFunction)array_transpose,
        METH_VARARGS, NULL},
    {"var",
        (PyCFunction)array_variance,
        METH_VARARGS | METH_KEYWORDS, NULL},
    {"view",
        (PyCFunction)array_view,
        METH_VARARGS | METH_KEYWORDS, NULL},
    {NULL, NULL, 0, NULL}           /* sentinel */
};

この記述から、 transpose メソッドが呼び出されると、array_transpose 関数が呼ばれることがわかりました。
次は、 array_transpose を読んでいくことになります。

array_transpose 関数を読む

まず、 array_transpose 関数の在り処を探します。
どうやら同じファイルの 1988 行目にあるようです。

$ find -name '*.c' -exec grep -nH array_transpose {} \;
./numpy/core/src/multiarray/getset.c:927:array_transpose_get(PyArrayObject *self)
./numpy/core/src/multiarray/getset.c:999:        (getter)array_transpose_get,
./numpy/core/src/multiarray/methods.c:1988:array_transpose(PyArrayObject *self, PyObject *args)
./numpy/core/src/multiarray/methods.c:2708:        (PyCFunction)array_transpose,

1988 行目を見ると次のようになっていました。

numpy/core/src/multiarray/methods.c
static PyObject *
array_transpose(PyArrayObject *self, PyObject *args)
{
    PyObject *shape = Py_None;
    Py_ssize_t n = PyTuple_Size(args);
    PyArray_Dims permute;
    PyObject *ret;

    if (n > 1) {
        shape = args;
    }
    else if (n == 1) {
        shape = PyTuple_GET_ITEM(args, 0);
    }

    if (shape == Py_None) {
        ret = PyArray_Transpose(self, NULL);
    }
    else {
        if (!PyArray_IntpConverter(shape, &permute)) {
            return NULL;
        }
        ret = PyArray_Transpose(self, &permute);
        npy_free_cache_dim_obj(permute);
    }

    return ret;
}

transpose 処理の本体は PyArray_Transpose 関数であることがわかりました。

PyArray_Transpose 関数を読む

例によって grep & find します。

$ find -name '*.c' -exec grep -nH PyArray_Transpose {} \;
./numpy/core/src/multiarray/calculation.c:71:        op = (PyArrayObject *)PyArray_Transpose(ap, &newaxes);
./numpy/core/src/multiarray/calculation.c:186:        op = (PyArrayObject *)PyArray_Transpose(ap, &newaxes);
./numpy/core/src/multiarray/convert.c:333:            new = PyArray_Transpose(self, NULL);
./numpy/core/src/multiarray/getset.c:929:    return PyArray_Transpose(self, NULL);
./numpy/core/src/multiarray/mapping.c:137:    new = PyArray_Transpose(*ret, &permute);
./numpy/core/src/multiarray/methods.c:2003:        ret = PyArray_Transpose(self, NULL);
./numpy/core/src/multiarray/methods.c:2009:        ret = PyArray_Transpose(self, &permute);
./numpy/core/src/multiarray/multiarraymodule.c:928:        ap2t = PyArray_Transpose(ap2, &newaxes);
./numpy/core/src/multiarray/multiarraymodule.c:1146:        tmp = (PyArrayObject *)PyArray_Transpose(arr, &new_axes);
./numpy/core/src/multiarray/shape.c:675:    return PyArray_Transpose(ap, &new_axes);
./numpy/core/src/multiarray/shape.c:683:PyArray_Transpose(PyArrayObject *ap, PyArray_Dims *permute)

numpy/core/src/multiarray/shape.c にお目当てのポイントが有ることがわかります。
だいぶ長くなったので、ここで一旦切りたいと思います。
疑問点等あればコメントください。

追記: 後編 を書きました

だいぶ時間が立ってしまいましたが、後編の記事として 実録!コードリーディング入門 ~ NumPy rollaxis & transpose 編 ~ (2/2) を公開しました!
こちらも合わせてぜひお読みください m(_ _)m

12
14
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
14