はじめに
numpy.ndarray
には、インデックス参照、スライス、ブロードキャストといったさまざま機能があります。
本記事では、機械学習でデータを扱う時に遭遇しそうな具体的な状況設定のもとで、それらの便利機能を利用した書き方を紹介します。
目標としては、これらを使うことでシンプルなコードを書くとともに、遅いと言われているfor文をなくしてパフォーマンスを上げることを目指します。
対象読者
-
numpy
の色々な書き方を知りたい方 -
ndarray
を使ったコードのパフォーマンスを少しでも上げたい方
ドキュメントを読むことに抵抗がない方は、こちらの"Less Basic"の章などを読んでみると発見があるかもしれません。
環境
Python 3.10.13
numpy 1.26.0
行と列を抽出する
二次元配列x
から、任意の行と列を同時に指定して抜き出す方法です。
紹介する3つの例のうち、下の2つの方が望ましいと言われています(1つ目の例だと、ビューが2度作成されてどうのこうの...と見たことがありますが真偽不明です)。
import numpy as np
x = np.arange(12).reshape(3, 4)
# array([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])
# 0,2行目、1,3列目を抽出
x[[0, 2]][:, [1, 3]] # これより
x[[[0], [2]], [1, 3]] # これと
x[np.ix_([0, 2], [1, 3])] # これの方が望ましい(と言われている)
# いずれも以下の同じ結果が得られる
# array([[ 1, 3],
# [ 9, 11]])
ix_
関数は、N
個の1次元要素列A_i
を受け取って、それぞれを(1, ..., len(A_i), ..., 1)
にサイズ変更したN
次元arrayをN
個返す関数です。
公式ドキュメントには、以下のように記されています。
Using
ix_
one can quickly construct index arrays that will index the cross product.a[np.ix_([1,3],[2,5])] returns the array [[a[1,2] a[1,5]], [a[3,2] a[3,5]]]
.
np.ix_([0, 2], [1, 3])
# (array([[0],
# [2]]),
# array([[1, 3]]))
argsortしつつ、ソートされた配列も得る
ある配列をソートした結果の配列と、並べ替えた添え字配列のどちらも得たい場合があると思います。その場合には、argsort
で得られた添え字配列からソート済み配列を構成することで、argsort
とsort
をどちらも実行してしまう無駄を省けます。
x = np.array([[1, 5],
[2, 0],
[3, 4]])
idx = x.argsort(axis=1)
# array([[0, 1],
# [1, 0],
# [0, 1]])
np.take_along_axis(x, idx, axis=1)
# array([[1, 5],
# [0, 2],
# [3, 4]])
# np.sort(x, axis=1)と同じ結果になる
応用例として、スライスで特定の列を抜き出してからargsort
することで、その列をキーとしてソートすることができます。
idx = x[:, 1].argsort(axis=0)
# array([1, 2, 0])
x[idx]
# array([[2, 0],
# [3, 4],
# [1, 5]])
# 2列目の大小で並べ替えられた
この方法は、他にもいくらでも応用がききます。例えば、各行の和の大小にしたがってソートすることもお手のものです。
x[x.sum(axis=1).argsort(axis=0)]
# array([[2, 0], 2 + 0 = 2
# [1, 5], 1 + 5 = 6
# [3, 4]]) 3 + 4 = 7
各行から別の配列に従って要素を取り出す
outputs
をn*m
要素の2次元配列、labels
をn
要素の1次元配列とします。
各i
に対し、outputs[i, labels[i]]
を取り出して新たな1次元配列を作成する方法です。
ちなみに、outputs
はshape=(バッチサイズ, クラス数)
の多クラス分類器の出力を、labels
はshape=(バッチサイズ,)
の正解ラベルを想定しています。そして、バッチ内の各要素の正解クラスに対する予測(例えばsoftmax出力)を取り出したい状況を考えます。
outputs = np.arange(12).reshape(3, 4)
# array([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])
labels = np.array([2, 0, 1])
outputs[np.arange(len(outputs)), labels]
# array([2, 4, 9])
# つまり、
# [ 0, 1, 2, 3][2] -> 2
# [ 4, 5, 6, 7][0] -> 4
# [ 8, 9, 10, 11][1] -> 9
出力形状は違えど、以下のように書くこともできます。
np.take_along_axis(outputs, labels.reshape(-1, 1), axis=1)
# array([[2],
# [4],
# [9]])
ある要素より大きい要素が配列内にいくつあるかを求める
outputs
、labels
は同上とします。
各i
に対し、正解クラスlabels[i]
に対するモデル出力outputs[i, labels[i]]
が、全クラスに対する出力outputs[i]
の中で何番目に大きいかを求める方法です。
geq = outputs >= np.take_along_axis(outputs, labels.reshape(-1, 1), axis=1)
# array([[ 0, 1, 2, 3], array([[2],
# [ 4, 5, 6, 7], >= [4],
# [ 8, 9, 10, 11]]) [9]])
# =
# array([[False, False, True, True],
# [ True, True, True, True],
# [False, True, True, True]])
geq.sum(axis=1)
# array([2, 4, 3])
# 2は[ 0, 1, 2, 3]の中で2番目に、
# 4は[ 4, 5, 6, 7]の中で4番目に、
# 9は[ 8, 9, 10, 11]の中で3番目に大きい
セグメンテーション結果を色分け表示する
ピクセルごとにクラス分け(以下の例では3クラス)されたセグメンテーション結果img
から、クラスごとに色付けされた画像を生成する方法です。palette[i]
は、クラスi
の色に対応するRGB値を表します。
やっていることは配列によるインデックス参照にすぎないのですが、意外と思いつきにくいと思っています。
img = np.array([[0, 1, 1], [1, 0, 2]])
palette = np.array([[255, 0, 0], [0, 255, 0], [0, 0, 255]])
palette[img]
# array([[[255, 0, 0],
# [ 0, 255, 0],
# [ 0, 255, 0]],
# [[ 0, 255, 0],
# [255, 0, 0],
# [ 0, 0, 255]]])
# つまり、
# [[R, G, G],
# [G, R, B]]
各クラスごとのマスクを得る
上と同じセグメンテーション結果img
から、各i
ごとに、img == i
で得られるbool
配列を計算する方法です。
要するに、以下の操作で得られるmask2
と同じ配列を計算したいです。
masks2 = np.empty((3, 2, 3), dtype=bool)
for i in range(3):
for j in range(2):
for k in range(3):
masks2[i][j][k] = (img[j][k] == i)
これは、以下のようにブロードキャストを応用することで実現できます。
classes = np.arange(3)
# array([0, 1, 2])
masks = img == classes.reshape(-1, 1, 1)
# array([[[ True, False, False],
# [False, True, False]],
# [[False, True, True],
# [ True, False, False]],
# [[False, False, False],
# [False, False, True]]])
ちなみに、masks.sum(axis=(1, 2))
によって、クラスごとのピクセル数を求めることができますが、この計算のみが目的なら、np.bincount
を使ってimg
から直接計算した方が無駄がないです。
masks.sum(axis=(1, 2))
# array([2, 3, 1])
np.bincount(img.flatten())
# array([2, 3, 1])
終わりに
ドキュメントは抽象的に書かれていることが多いので、反対に、具体的な例を持ち出しながらndarray
の便利演算を紹介してみました。この記事で、少しでもndarray
への理解が深まれば幸いです。