やりたかったこと
- numpy2次元配列を3次元Onehot表現
- (Qiitaに初投稿😀)
- より良い書き方を知りたいです!詳しい方教えて下さい。
想定する場面
- 画像セグメンテーションであれば
- 2次元配列は「正解ラベルの元画像」に相当
- 3次元配列は「加工済の正解ラベル」に相当
コード
import numpy as np
# クラス数
class_num = 4
# 各要素がクラスに該当する2次元の配列
arr_2d = np.array([[0, 1]
,[2, 1]])
# 各層がクラスに対応する3次元の配列
np.identity(class_num)[arr_2d].transpose(2,0,1)
まだやってないこと
- numpyだけでなくpytorchのTensor配列でも同様のことができそう
- 複数枚の処理