zerowarurei
@zerowarurei

Are you sure you want to delete the question?

Leaving a resolved question undeleted may help others!

numpyでの多重配列内での比較

Q&A

Closed

解決したいこと

trg=np.array([1,2,3])
arr=np.array([[1,2,3],[2,3,4],[1,2,3]])

という二つの配列があった場合に、arr内の各要素がtrgと一致するかの判定をできるだけ速く行いたいです。
(説明下手で申し訳ないですが、この場合[True, False, True]が得たいです。)

自分で試したこと

import numpy as np
import time 

trg=np.array([1,2,3], dtype=np.uint8)
arr=np.array([[1,2,3],[2,3,4],[1,2,3]]*1000000, dtype=np.uint8)

st=time.time()
ans=sum((arr==trg).T)==3
print(ans)
print(f'elapsed time: {time.time()-st}')
[ True False  True ...  True False  True]
elapsed time: 0.036267995834350586

これより速くしたいのですが、いい方法があれば教えてもらえませんでしょうか。

0

3Answer

質問の状況(arrtrgのdtypeが両方np.uint8で、trgの長さが3)であれば、
以下のように一行分を一要素として見るようにviewを設定するのが速いです。

arr.view('V3').ravel() == trg.view('V3')

多様な状況に対応できるように拡張すると以下のようになります(ただしarrtrgのdtypeが異なる場合は失敗することがあります)。

def dim_equal(arr, trg):
    arr = np.ascontiguousarray(arr)
    trg = np.ascontiguousarray(trg)
    new_dtype = np.dtype((np.void, arr.dtype.itemsize * arr.shape[1]))
    arr_v = arr.view(new_dtype).ravel()
    trg_v = trg.view(new_dtype).ravel()
    return arr_v == trg_v

numbaを使えばもっと速くなります。おそらくこの方法が最速だと思われます。

from numba import njit

@njit
def dim_equal(arr, trg):
    m, n = arr.shape
    out = np.empty(m, dtype=np.bool8)
    for i in range(m):
        tmp = 0
        for j in range(n):
            if arr[i, j] == trg[j]:
                tmp += 1
        out[i] = tmp == n
    return out
2Like

Comments

  1. @zerowarurei

    Questioner

    @nkayさん
    回答ありがとうございます。numbaで型指定して実行したところ3倍ほど早く実行できました。viewについても一行分を一要素としてみた方が効率が良いことに納得です。
    教えていただきありがとうございます。

色々と試してみましたが、@zerowarureiさんのプログラムが今のところ一番速いです。

np.sum()を使って処理
import numpy as np
import time 

trg=np.array([1,2,3], dtype=np.uint8)
arr=np.array([[1,2,3],[2,3,4],[1,2,3]]*1000000, dtype=np.uint8)


st=time.time()
ans = np.sum(arr==trg, axis=1) == 3
print(ans)
print(f'elapsed time: {time.time()-st}')
"""
[ True False  True ...  True False  True]
elapsed time: 0.0965876579284668
"""

numpyのsumを使えば、転置しなくても良くて少し速くなると思いましたが、遅かったです。
daskというものがあり、並列処理に便利です。少し使ってみましたが、今回のケースでは速くなりませんでした。

daskを使って並列処理
import numpy as np
import dask.array as da
import time 

trg=np.array([1,2,3], dtype=np.uint8)
arr=np.array([[1,2,3],[2,3,4],[1,2,3]]*1000000, dtype=np.uint8)
darr=da.from_array(arr, chunks=600000)


st=time.time()
ans=sum((darr==trg).T)==3
print(ans.compute())
print(f'elapsed time: {time.time()-st}')
"""
[ True False  True ...  True False  True]
elapsed time: 0.05028700828552246
"""

一応コードの流れを紹介します。

daskを使って最初に5つの配列に分割している
arr=np.array([[1,2,3],[2,3,4],[1,2,3]]*1000000, dtype=np.uint8)
darr=da.from_array(arr, chunks=600000)
darr.visualize()

Unknown-97.png

trgと比べる
arr=np.array([[1,2,3],[2,3,4],[1,2,3]]*1000000, dtype=np.uint8)
darr=da.from_array(arr, chunks=600000)
process = (darr == trg)
process.visualize()

Unknown-98.png

転置する
arr=np.array([[1,2,3],[2,3,4],[1,2,3]]*1000000, dtype=np.uint8)
darr=da.from_array(arr, chunks=600000)
process = (darr == trg)
process_t = process.T
process_t.visualize()

Unknown-99.png

比較する3つの要素を足す
arr=np.array([[1,2,3],[2,3,4],[1,2,3]]*1000000, dtype=np.uint8)
darr=da.from_array(arr, chunks=600000)
process = (darr == trg)
process_t = process.T
process_t_sum = sum(process_t)
process_t_sum.visualize()

Unknown-100.png

3つの要素の和が3かどうか(コード全容)
arr=np.array([[1,2,3],[2,3,4],[1,2,3]]*1000000, dtype=np.uint8)
darr=da.from_array(arr, chunks=600000)
process = (darr == trg)
process_t = process.T
process_t_sum = sum(process_t)
flag = process_t_sum == 3
flag.visualize()

Unknown-96.png

このような処理をしていてます。

1Like

Comments

  1. @zerowarurei

    Questioner

    @yusuke_s_yusukeさん
    丁寧にありがとうございます。
    私の環境下でdaskでchunks=200000としたところ、elapsed time: 0.023667097091674805
    と処理速度が改善しました。
    daskも今回初めて知りました。教えていただき大変ありがとうございます。
  2. それは良かったです。
    私もdask初めて使ってみたので、環境依存(またはchunks依存?)があることを知らなかったので勉強になりました。

質問の回答になっておらず、処理速度では劣りますが、numpyのall()を使った書き方で比較することもできます。

import numpy as np
import time 

trg=np.array([1,2,3], dtype=np.uint8)
arr=np.array([[1,2,3],[2,3,4],[1,2,3]]*1000000, dtype=np.uint8)


st=time.time()
ans = (arr == trg).all(axis=1)
print(ans)
print(f'elapsed time: {time.time()-st}')
[ True False  True ...  True False  True]
elapsed time: 0.08394408226013184

参考:https://teratail.com/questions/140738

0Like

Comments

  1. @zerowarurei

    Questioner

    @yusuke_s_yusukeさん
    回答ありがとうございます。
    numpyのall()にaxis指定できること初めて知りました。勉強になります。

Your answer might help someone💌