この記事は、 numpy.rollaxis
の実装を追った際の調査過程をメモしたものです。
コードリーディングの過程がわかる情報というのはありそうでなかなかないので、共有します。
コードリーディング初心者の方のお役に立てば幸いです。
きっかけ
きっかけは Chainer の softmax_cross_entropy
の実装を読んでいたところから始まります。
numpy.rollaxis
という関数を読んでいることがわかり、これは何をやっているんだろうと疑問に思いました。
def forward_cpu(self, inputs):
x, t = inputs
if chainer.is_debug():
_check_input_values(x, t, self.ignore_label)
log_y = log_softmax._log_softmax(x)
if self.cache_score:
self.y = numpy.exp(log_y)
if self.class_weight is not None:
shape = [1 if d != 1 else -1 for d in six.moves.range(x.ndim)]
log_y *= _broadcast_to(self.class_weight.reshape(shape), x.shape)
log_yd = numpy.rollaxis(log_y, 1) ★ これ
log_yd = log_yd.reshape(len(log_yd), -1)
log_p = log_yd[numpy.maximum(t.ravel(), 0), numpy.arange(t.size)]
...
ドキュメントを読んで見る
とりあえず NumPyのドキュメント を読んでみると、次の一行がサラッと書かれているだけです。
Roll the specified axis backwards, until it lies in a given position.
軸をロールすると言われても、具体的にデータがどう見えるのかがピンとこなかったため、コードを読んでみることにしてみました。
numpy.rollaxis
のコードを読む
実装箇所の特定
まず、git から NumPy のソースコードを拾います。
$ git clone https://github.com/numpy/numpy.git
私はよく find
& grep
コマンドを使うので、 rollaxis
を定義している箇所を探してみます。
これらのコマンドに渡すパラメータで、いくつかポイントがあります。
-
-type f
で、検索対象をファイルに絞る。そうしないと、ディレクトリに対するgrep
が走ってしまい、エラー出力がやかましくなります。 -
grep
には-nH
オプションを付けて、ファイル名と行数を表示させる。
上記のポイントを踏まえてコマンドを実行したところ、下記のような出力が得られました。
$ find -type f -exec grep -nH 'def rollaxis' {} \;
./numpy/core/numeric.py:1415:def rollaxis(a, axis, start=0):
numpy/core/numeric.py
の1415行目で定義されているということなので、見てみましょう。
def rollaxis(a, axis, start=0):
"""
Roll the specified axis backwards, until it lies in a given position.
This function continues to be supported for backward compatibility, but you
should prefer `moveaxis`. The `moveaxis` function was added in NumPy
1.11.
Parameters
----------
a : ndarray
Input array.
axis : int
The axis to roll backwards. The positions of the other axes do not
change relative to one another.
start : int, optional
The axis is rolled until it lies before this position. The default,
0, results in a "complete" roll.
Returns
-------
res : ndarray
For NumPy >= 1.10.0 a view of `a` is always returned. For earlier
NumPy versions a view of `a` is returned only if the order of the
axes is changed, otherwise the input array is returned.
See Also
--------
moveaxis : Move array axes to new positions.
roll : Roll the elements of an array by a number of positions along a
given axis.
Examples
--------
>>> a = np.ones((3,4,5,6))
>>> np.rollaxis(a, 3, 1).shape
(3, 6, 4, 5)
>>> np.rollaxis(a, 2).shape
(5, 3, 4, 6)
>>> np.rollaxis(a, 1, 4).shape
(3, 5, 6, 4)
"""
n = a.ndim
axis = normalize_axis_index(axis, n)
if start < 0:
start += n
msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
if not (0 <= start < n + 1):
raise AxisError(msg % ('start', -n, 'start', n + 1, start))
if axis < start:
# it's been removed
start -= 1
if axis == start:
return a[...]
axes = list(range(0, n))
axes.remove(axis)
axes.insert(start, axis)
return a.transpose(axes)
NumPy のコードはドキュメントも含んでいるため、コメントがやたら長いです。
コードの中身としては、
-
axes
という (0, 1, ..., n-1) のリストを作る (n: もとの配列の次元) -
axis
を取り除く -
start
番地にaxis
を追加 -
ndarray.transpose
を呼び出す
となります。
やっていることは axis
番目の軸と start
番目の軸を交換するという処理です。
start
のデフォルトが0で、axis
番目の軸を先頭に持ってくるという処理が、ロールしているように感じられるのでしょう。
numpy.ndarray.transpose
のコードを読む
numpy.rollaxis
の処理は、実際は transpose
であるということがわかりました。
2次元行列に対する転置であれば処理内容は容易に想像がつきますが、3次元以上の配列に対する転置は、一体どういう処理をしているのでしょうか?
コードを読んで見るために、まずは find & grep
でコードを探してみます。
$ find -type f -exec grep -nH 'def transpose' {} \;
./numpy/core/fromnumeric.py:499:def transpose(a, axes=None):
./numpy/linalg/linalg.py:227:def transpose(a):
./numpy/ma/core.py:6863:def transpose(a, axes=None):
3つ出てきましたが、ここは rollaxis
と同じ core
ディレクトリを掘るのが良さそうです。
というわけで、 frumnumeric.py
の 499 行目を見てみます。
def transpose(a, axes=None):
"""
Permute the dimensions of an array.
Parameters
----------
a : array_like
Input array.
axes : list of ints, optional
By default, reverse the dimensions, otherwise permute the axes
according to the values given.
Returns
-------
p : ndarray
`a` with its axes permuted. A view is returned whenever
possible.
See Also
--------
moveaxis
argsort
Notes
-----
Use `transpose(a, argsort(axes))` to invert the transposition of tensors
when using the `axes` keyword argument.
Transposing a 1-D array returns an unchanged view of the original array.
Examples
--------
>>> x = np.arange(4).reshape((2,2))
>>> x
array([[0, 1],
[2, 3]])
>>> np.transpose(x)
array([[0, 2],
[1, 3]])
>>> x = np.ones((1, 2, 3))
>>> np.transpose(x, (1, 0, 2)).shape
(2, 1, 3)
"""
return _wrapfunc(a, 'transpose', axes)
これまたコメントがやたら長いですが、やっているのは _wrapfunc
を呼ぶだけです。
どんどんほっていきましょう。
def _wrapfunc(obj, method, *args, **kwds):
try:
return getattr(obj, method)(*args, **kwds)
# An AttributeError occurs if the object does not have
# such a method in its class.
# A TypeError occurs if the object does have such a method
# in its class, but its signature is not identical to that
# of NumPy's. This situation has occurred in the case of
# a downstream library like 'pandas'.
except (AttributeError, TypeError):
return _wrapit(obj, method, *args, **kwds)
こうやって見るとオブジェクトのメソッドを呼んでいるだけです。
しかし、そもそも numpy.ndarray
のコードをよんでいるはずなので、何か辻褄が合わない…
それもそのはず、これは numpy.transpose
のコードだからです。
それでは、numpy.ndarray.transpose
のコードはどこにあるのでしょうか?
numpy.ndarray
のありかを探す
まずは、 ndarray
のありかを探す必要があるようです。
$ find -name '*.py' -exec grep -nH 'ndarray' {} \;
して、ndarray を定義している箇所を探します。
流石にこのキーワードだと沢山 Hit しますので、気合でそれっぽい箇所を探します。
大抵がコメントなので、コメントっぽい箇所を流し読みして抜粋したのが以下のものです。
./numpy/core/memmap.py:4:from .numeric import uint8, ndarray, dtype
/numpy/core/numeric.py:18: min_scalar_type, ndarray, nditer, nested_iters, promote_types,
./numpy/core/numeric.py:48: 'newaxis', 'ndarray', 'flatiter', 'nditer', 'nested_iters', 'ufunc',
どうやら、 numeric.py
になにかいそうなので読んでみます。
from __future__ import division, absolute_import, print_function
import collections
import itertools
import operator
import sys
import warnings
import numpy as np
from . import multiarray
from .multiarray import (
_fastCopyAndTranspose as fastCopyAndTranspose, ALLOW_THREADS,
BUFSIZE, CLIP, MAXDIMS, MAY_SHARE_BOUNDS, MAY_SHARE_EXACT, RAISE,
WRAP, arange, array, broadcast, can_cast, compare_chararrays,
concatenate, copyto, count_nonzero, dot, dtype, empty,
empty_like, flatiter, frombuffer, fromfile, fromiter, fromstring,
inner, int_asbuffer, lexsort, matmul, may_share_memory,
min_scalar_type, ndarray, nditer, nested_iters, promote_types,
putmask, result_type, set_numeric_ops, shares_memory, vdot, where,
zeros, normalize_axis_index)
すると、今度は multiarray.py
を読もうとしたくなるのが人情ですが、残念ながらそのようなファイルはありません。
この場合、Cで書かれた拡張モジュールを使っている可能性が大です。
実際、NumPy は計算部分がCで書かれているから速いとうたわれているライブラリですので、次はCの世界を掘り下げていくことになります。
(補足)
pip
でインストールした NumPy のディレクトリを見てみると、たしかに multiarray.so
が存在しており、C側のモジュールを読んでいることがわかります。
~/.pyenv/versions/anaconda3-4.1.1/pkgs/numpy-1.11.1-py35_0/lib/python3.5/site-packages/numpy/core$ ls
arrayprint.py fromnumeric.py include lib multiarray.so operand_flag_tests.so setup.py tests
cversions.py function_base.py info.py machar.py multiarray_tests.so __pycache__ shape_base.py umath.so
defchararray.py generate_numpy_api.py __init__.py memmap.py numeric.py records.py struct_ufunc_test.so umath_tests.so
_dummy.so getlimits.py _internal.py _methods.py numerictypes.py setup_common.py test_rational.so
Cを掘り下げる
Cファイルの何処かで、 transpose
という文字列を定義している箇所があるはずです。
そこで、 transpose
を定義している箇所を探します。
$ find . -type f -exec grep -nH '"transpose"' {} \;
./numpy/core/src/multiarray/methods.c:2707: {"transpose",
./numpy/core/src/multiarray/scalartypes.c.src:1475: "transpose",
./numpy/core/src/multiarray/scalartypes.c.src:1955: {"transpose",
numpy/core/src/multiarray/methods.c
の 2707 行目付近を見てみましょう。
NPY_NO_EXPORT PyMethodDef array_methods[] = {
...
{"transpose",
(PyCFunction)array_transpose,
METH_VARARGS, NULL},
{"var",
(PyCFunction)array_variance,
METH_VARARGS | METH_KEYWORDS, NULL},
{"view",
(PyCFunction)array_view,
METH_VARARGS | METH_KEYWORDS, NULL},
{NULL, NULL, 0, NULL} /* sentinel */
};
この記述から、 transpose
メソッドが呼び出されると、array_transpose
関数が呼ばれることがわかりました。
次は、 array_transpose
を読んでいくことになります。
array_transpose
関数を読む
まず、 array_transpose
関数の在り処を探します。
どうやら同じファイルの 1988 行目にあるようです。
$ find -name '*.c' -exec grep -nH array_transpose {} \;
./numpy/core/src/multiarray/getset.c:927:array_transpose_get(PyArrayObject *self)
./numpy/core/src/multiarray/getset.c:999: (getter)array_transpose_get,
./numpy/core/src/multiarray/methods.c:1988:array_transpose(PyArrayObject *self, PyObject *args)
./numpy/core/src/multiarray/methods.c:2708: (PyCFunction)array_transpose,
1988 行目を見ると次のようになっていました。
static PyObject *
array_transpose(PyArrayObject *self, PyObject *args)
{
PyObject *shape = Py_None;
Py_ssize_t n = PyTuple_Size(args);
PyArray_Dims permute;
PyObject *ret;
if (n > 1) {
shape = args;
}
else if (n == 1) {
shape = PyTuple_GET_ITEM(args, 0);
}
if (shape == Py_None) {
ret = PyArray_Transpose(self, NULL);
}
else {
if (!PyArray_IntpConverter(shape, &permute)) {
return NULL;
}
ret = PyArray_Transpose(self, &permute);
npy_free_cache_dim_obj(permute);
}
return ret;
}
transpose 処理の本体は PyArray_Transpose
関数であることがわかりました。
PyArray_Transpose
関数を読む
例によって grep & find
します。
$ find -name '*.c' -exec grep -nH PyArray_Transpose {} \;
./numpy/core/src/multiarray/calculation.c:71: op = (PyArrayObject *)PyArray_Transpose(ap, &newaxes);
./numpy/core/src/multiarray/calculation.c:186: op = (PyArrayObject *)PyArray_Transpose(ap, &newaxes);
./numpy/core/src/multiarray/convert.c:333: new = PyArray_Transpose(self, NULL);
./numpy/core/src/multiarray/getset.c:929: return PyArray_Transpose(self, NULL);
./numpy/core/src/multiarray/mapping.c:137: new = PyArray_Transpose(*ret, &permute);
./numpy/core/src/multiarray/methods.c:2003: ret = PyArray_Transpose(self, NULL);
./numpy/core/src/multiarray/methods.c:2009: ret = PyArray_Transpose(self, &permute);
./numpy/core/src/multiarray/multiarraymodule.c:928: ap2t = PyArray_Transpose(ap2, &newaxes);
./numpy/core/src/multiarray/multiarraymodule.c:1146: tmp = (PyArrayObject *)PyArray_Transpose(arr, &new_axes);
./numpy/core/src/multiarray/shape.c:675: return PyArray_Transpose(ap, &new_axes);
./numpy/core/src/multiarray/shape.c:683:PyArray_Transpose(PyArrayObject *ap, PyArray_Dims *permute)
numpy/core/src/multiarray/shape.c
にお目当てのポイントが有ることがわかります。
だいぶ長くなったので、ここで一旦切りたいと思います。
疑問点等あればコメントください。
追記: 後編 を書きました
だいぶ時間が立ってしまいましたが、後編の記事として 実録!コードリーディング入門 ~ NumPy rollaxis & transpose 編 ~ (2/2) を公開しました!
こちらも合わせてぜひお読みください m(_ _)m