88
80

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

AtCoderで Python を高速化する Numpy + Numba を使う

Last updated at Posted at 2019-06-21

元の題名は「AtCoderで言語アップデートがあるので、PythonでNumbaが使える設定を考えてみた」でしたが、Language Test が始まって Numba が使えることになったので、題名を変更して、記事の内容も変更しました。(2020/1/31)

最近、Python のコードを早く書く練習として AtCoder を始めました。AtCoder は、最近人気のようで、モチベーションを保てるので、学習手段としてはかなり有効だと思います。

自分の場合、普段から Python を使っていますが、ライブラリー豊富だし、ボトルネックは最終 Numba 又は Cython を使えばいいので、Python は便利な言語だと思っています。でも、AtCoder をしていると、Python は遅くて C++ の方が向いていると感じるので、C++ を勉強しようかとも思って実際にやってみたのですが、AtCoder に限れば C++ を使うことはそれほど難しくありません。しかし、普段使うのは、C++ よりも Python の方がずっと便利なので、AtCoder の方も引き続き Python でいこうと思っています。その時に AtCoder の方で言語のアップデートが丁度予定されていたのでので、Numba が使える設定を考えてみました。

今回のアップデートで、Numba の JIT が使えるようになるとともに、JIT のキャッシュと AOT も利用することができるようになりました。まず、それらの利用の仕方について簡単に書きます。

ベンチマークの結果

Numba を使うとどの程度の効果があるかを知るために、ベンチマークをしてみました。Python で TLE が続出した AtCoder の AtCoder Beginner Contest 129 の問題D を使っています。データとしては、入力例 2 と制約の最大である縦 2000 行 横 2000 列のグリッドを乱数を使って作成したものを使いました。

種別 入力例2(8x8) 2000x2000 グリッド
Python3.7 24 ms 2263 ms
PyPy3.6 48 ms 533 ms
Numba (JIT) 757 ms 884 ms
Numba (JITキャッシュ有り) 496 ms 627 ms
Numba (AOT) 112 ms 228 ms
Cython 117 ms 228 ms
Pythran 3 ms 202 ms

結果は、上の表のとおりで、Numba を使う効果は大いにあることがわかります。Numba JIT は高速化したい関数に@njit とデコレータを付けるだけでなのですが、それだけで約900ms となって AC にできてしまいます。

このケースでは、JITする関数がそれほど複雑でないためコンパイル時間が約250msなので、それで十分なのですが、コードが複雑になるとコンパイル時間が長くなってしまいます。その場合には、キャッシュ又は AOT が使えます。

キャッシュを使った場合は、約630msで PyPy の実行時間と同じ程度になります。キャッシュを使う場合には、Numpy と Numba の読み込みが必要で最低でも約500msが必要になります。「AtCoderはJavaで2倍の余裕があるようにしています」というルールだそうなので、Java よりも処理は高速になる場合が多いので、ほぼ問題なく AC できるはずです。

AOT を使うと Python3.7よりも約10倍、PyPy3 よりも約3倍高速化することができます。Java と同等の速さになるので、通せない問題は事実上なくなります。なお、AOT でコンパイルしてしまえば Numba のライブラリーは不要になりますが、Numpy は必要なので、最低でも約100msの時間は必要です。

今回のアプデートで Cython も使えるようになりました。Cython も高速で、静的型付言語の経験があれば、Numba よりも使いやすいかもしれません。いずれにしても、Numba と Cython が使えるようになったことで、Python のコードを高速化することが容易になりました。

殆ど知られていないのですが Pythran という Python のコードを C++ に変換するコンパイラがあります。AtCoderでは『Pythonの高速化テクニック:C++で書き直す。』というネタがあるそうですが、まさにそれをやってくれるライブラリーです。科学技術計算用で、Python と Numpy のサブセットに対応しています。Pythran の方では、Cython との統合を考えているようなので(参照 Getting the best of every world: Cython and Pythran working together)、次回の言語アップデート時に、Cython のライブラリーとして使えるかどうか検討すればいいと思っています。もし、うまく統合できれば Cython のコードを書くのがかなり楽になると思います。

サンプルコード

AtCoder Beginner Contest 144 の問題E を使って、それぞれのケースで、どのようにコードを書けばいいのかを説明しておきます。この問題は、Python だけでは AC が難しいですが、Numpy を使うと簡単に AC できるので、あまり適当なケースではないかもしれませんが、コードが比較的短いので使いました。

Numba JIT

Numba を JIT で使う場合には、高速化したい関数に@njit とデコレータを付けるだけで手軽に高速化できます。signature を使って型指定をすることができますが、コンパイルを実行する時点が違うだけで、トータルの処理時間は同じです。

import numpy as np
from numba import njit

@njit
def solve(A, F, K):
    lo = 0
    hi = 10 ** 12
    while lo < hi:
        mid = (lo + hi) // 2
        total_train = 0
        for i in range(A.shape[0]):
            need = A[i] - mid // F[i]
            total_train += need if need > 0 else 0
        if total_train <= K:
            hi = mid
        else:
            lo = mid + 1
    return lo

stdin = np.fromstring(open(0).read(), dtype=np.int64, sep=' ')
N, K = stdin[:2]
A = stdin[2: 2 + N]
F = stdin[2 + N:]
print(solve(np.sort(A)[::-1], np.sort(F), K))

以下のようにJITをするコードを増やすことが可能です。しかし、実行時間が上の場合よりの約 0.7s 遅くなります。np.sort() のコンパイルに時間がかかっているためで、JIT の場合には、コンパイル時間を考慮して、ボトルネックになる最小限のコードだけをJITすることが重要です。

import numpy as np
from numba import njit

@njit
def solve(stdin):
    N, K = stdin[:2]
    A = stdin[2: 2 + N]
    A = np.sort(A)[::-1]
    F = np.sort(stdin[2 + N:])
    lo = 0
    hi = 10 ** 12
    while lo < hi:
        mid = (lo + hi) // 2
        total_train = 0
        for i in range(A.shape[0]):
            need = A[i] - mid // F[i]
            total_train += need if need > 0 else 0
        if total_train <= K:
            hi = mid
        else:
            lo = mid + 1
    return lo

print(solve(np.fromstring(open(0).read(), dtype=np.int64, sep=' ')))

Numba JIT キャッシュあり

Numba は、@njit(cache=True) とすることで、キャシュを有効にすることができます。キャッシュをコンパイルフェーズで作成する必要があるのですが、その方法は、2つあります。

一つは、signature を使って型指定をすることです。型指定があると最初に関数を読み込んだ時点で JIT が走ります。

下のサンプルでは、クロージャを使っていますが、クロージャを使うと@njitデコレータを1個だけつけるといいので手間が省けます。ここでは、必要がないので使っていませんが、nonlocal文を使うことも可能です。ただし、再帰関数の場合は、クロージャを使うことはできません。

import numpy as np
from numba import njit

@njit('(i8[:],)', cache=True)
def solve(stdin):
    N, K = stdin[:2]
    A = stdin[2: 2 + N]
    A = np.sort(A)[::-1]
    F = np.sort(stdin[2 + N:])

    def binary_search():
        nonlocal A
        lo = 0
        hi = 10 ** 12
        while lo < hi:
            mid = (lo + hi) // 2
            total_train = 0
            for i in range(A.shape[0]):
                need = A[i] - mid // F[i]
                total_train += need if need > 0 else 0
            if total_train <= K:
                hi = mid
            else:
                lo = mid + 1
            A += 1
        return lo

    return binary_search()

stdin = np.fromstring(open(0).read(), dtype=np.int64, sep=' ')
print(solve(stdin))

もう一つの方法は、コンパイルフェーズでサンプルデータを与えてコードを実行させることです。コンパイルフェーズの時の実行コマンドは'python3.8 {dirname}/{basename} ONLINE_JUDGE'なので、コマンドライン引数により区別が可能です。

import sys
import numpy as np
from numba import njit

@njit(cache=True)
def solve(A, F, K):
    lo = 0
    hi = 10 ** 12
    while lo < hi:
        mid = (lo + hi) // 2
        total_train = 0
        for i in range(A.shape[0]):
            need = A[i] - mid // F[i]
            total_train += need if need > 0 else 0
        if total_train <= K:
            hi = mid
        else:
            lo = mid + 1
    return lo

if sys.argv[-1] == 'ONLINE_JUDGE':
    s = """\
3 5
4 2 1
2 3 1
    """
else:
    s = open(0).read()
stdin = np.fromstring(s, dtype=np.int64, sep=' ')
N, K = stdin[:2]
A = stdin[2: 2 + N]
F = stdin[2 + N:]
print(solve(np.sort(A)[::-1], np.sort(F), K))

Numba AOT

Numba は、AOT も可能です。以下が公式マニュアルをみて、すぐに思いつくコードです。コマンドライン引数により、コンパイルフェーズかどうかが区別できることを利用しています。

import sys
import numpy as np

if sys.argv[-1] == 'ONLINE_JUDGE':
    from numba.pycc import CC
    cc = CC('my_module')

    @cc.export('solve', '(i8[:],i8[:],i8)')
    def solve(A, F, K):
        lo = 0
        hi = 10 ** 12
        while lo < hi:
            mid = (lo + hi) // 2
            total_train = 0
            for i in range(A.shape[0]):
                need = A[i] - mid // F[i]
                total_train += need if need > 0 else 0
            if total_train <= K:
                hi = mid
            else:
                lo = mid + 1
        return lo

    cc.compile()
    exit(0)

from my_module import solve
stdin = np.fromstring(open(0).read(), dtype=np.int64, sep=' ')
N, K = stdin[:2]
A = stdin[2: 2 + N]
F = stdin[2 + N:]
print(solve(np.sort(A)[::-1], np.sort(F), K))

デコレータを使わない書き方もできます。

import numpy as np

def solve(stdin):
    N, K = stdin[:2]
    A = stdin[2: 2 + N]
    A = np.sort(A)[::-1]
    F = np.sort(stdin[2 + N:])

    def binary_search():
        lo = 0
        hi = 10 ** 12
        while lo < hi:
            mid = (lo + hi) // 2
            total_train = 0
            for i in range(A.shape[0]):
                need = A[i] - mid // F[i]
                total_train += need if need > 0 else 0
            if total_train <= K:
                hi = mid
            else:
                lo = mid + 1
        return lo

    return binary_search()

def main():
    stdin = np.fromstring(open(0).read(), dtype=np.int64, sep=' ')
    print(solve(stdin))

def cc_export():
    from numba.pycc import CC
    cc = CC('my_module')
    cc.export('solve', '(i8[:],)')(solve)
    cc.compile()

if __name__ == '__main__':
    import sys
    if sys.argv[-1] == 'ONLINE_JUDGE':
        cc_export()
        exit(0)
    from my_module import solve
    main()

このコードは、一見複雑なようですが、main() 関数の後ろに AOT 関係のコードを追加するだけなので、スニペットに登録しておけば、実際の手間としてはキャッシュの場合とかわらなくなります。また、# from my_module import solve と1箇所コメントを付けるだけでデバッグができるようになるというのもメリットです。

サンプルデータを与えてコードを実行させて、型を自動で取得するコードも可能です。

import numpy as np

def solve(stdin):
    N, K = stdin[:2]
    A = stdin[2: 2 + N]
    A = np.sort(A)[::-1]
    F = np.sort(stdin[2 + N:])

    def binary_search():
        lo = 0
        hi = 10 ** 12
        while lo < hi:
            mid = (lo + hi) // 2
            total_train = 0
            for i in range(A.shape[0]):
                need = A[i] - mid // F[i]
                total_train += need if need > 0 else 0
            if total_train <= K:
                hi = mid
            else:
                lo = mid + 1
        return lo

    return binary_search()

def main(s, solve):
    stdin = np.fromstring(s, dtype=np.int64, sep=' ')
    print(solve(stdin))

def cc_export(s):
    from numba import njit
    from numba.pycc import CC
    solve_jit = njit(solve)
    main(s, solve_jit)
    cc = CC('my_module')
    cc.export('solve', solve_jit.nopython_signatures[0])(solve)
    cc.compile()

if __name__ == '__main__':
    import sys
    if sys.argv[-1] == 'ONLINE_JUDGE':
        s = """\
3 5
4 2 1
2 3 1
        """
        cc_export(s)
        exit(0)
    from my_module import solve
    s = open(0).read()
    main(s, solve)

このコードは、あまり実用的ではありませんが、テストデータを与えて一度実行させるとnopython_signaturesで型情報が取得できるので、個人用のツールを作成するのに使えると思います。これを利用したツールを、GitHub の atcoder-numba のページ
の方で公開しているのでそちらも参考にしてください。

Numba を使う場合の注意点

初めて Numba を使う人に簡単に注意点を少しだけ書いておきます。

Numba は、関数の頭に @jit を付けるだけで使えるので、簡単に使えると思うかもしれませんが、Python と Numpy の一部の型や関数しか使えないため、Python のように使い易いわけではありません。しばしば Numba がコンパイルできなくて、Python のインタープリタで実行する Object モードになってしまい、Numba を使っても速くならなかったということになってしまいます。

現時点では、データ型で数値型以外に事実上使える型は、Numpy の ndarray だけです。Numba を使う前に Numpy を学習しましょう。Numpy は、科学技術計算やデータサイエンスの分野では標準的なライブラリーとして使われており、GPU を使う場合でも 行列計算で PyTorch が Numpy 互換を採用し、Google の方も Jax で Numpy 互換を採用しています。だから、学習しても損になることはありません。Numpy がわかれば Numba の方はドキュメントの5分間ガイドをみればどうすればいいかは簡単に理解できると思います。

Numba の典型的な使い方は、Maspy さんのホームページにある[numpy] 2次元配列の高速化 のようなケースです。

そこに書いてあるように、numpyで簡潔に書ける処理はある程度限定されており、全ての計算をnumpyで書くのは難しい場合もあります。その場合に有効な手段になるのが Numba です。以下のように、2重ループによる実装をしても大丈夫で、Python のコードよりも 100 倍ぐらい速くなります。

import numpy as np
from numba import njit

@njit
def calc_comb(comb):
    comb[0][0] = 1
    for n in range(1,N+1):
        comb[n][0] = 1
        for k in range(1,N+1):
            comb[n][k] = (comb[n-1][k-1] + comb[n-1][k]) % MOD

N = 10000
MOD = 10 ** 10
comb = np.zeros((N + 1, N + 1), dtype=np.int64)
calc_comb(comb)

このコードの中で@jitではなくて@njitを使っているのは、Object モードになった場合にエラーにするためです。Object モードの場合には毎回コンパイルが走るので、AtCoder の場合 Object モードになって速くなるケースは考えられないので、@njitを使うようにした方がいいです。なお、@njitは、@jit(nopython=True)の省略形で、Numba のドキュメントでも使用を勧めています。

また、Maspy さんのホームページは、競技プログラミングで numpy や scipy の使い方がわかる数少ない情報の一つなので参考にするといいと思います。

numba でも、最近 @jitclass, str, list, Typed Dict が使えるようになっています。しかし、まだまだ問題が多いので初心者は使わない方がいいと思います。@jitclass を使うとコードが長くなるのが普通ですが、キャッシュも AOT もできません。そのため、AtCoder では、TLE が心配で使いずらいです。(2020-12-13追記 @jitclass でキャッシュも AOT もできます。まだ、問題もありますが使えます。) また、list や dict を始めとして動的データ構造用のツールが全滅状態です(2020-12-13追記 ver0.52.0 で list が高速になっています。「numba の Typed List が速くなった」に書きました。次回の Atcoder の言語アップデートには期待できると思います。) 。そのため、自作ライブラリを用意しておくといいいと思います。

AtCoder では、制約から必要な領域の最大を計算することが殆どの場合に可能なので、事前に、必要な領域を持った ndarray を作っておけばいいので、自作ライブラリを作ること自体は、それほど難しくありません。

自作ライブラリを作る場合、jitclass が使いづらいというのが難点ですが、取り敢えずは以下の優先度付きキューのサンプルようにクロージャを使うことを考えています。ローカル環境にない変数へのアクセスが可能なので、最低限の隠ぺいは可能です。

@njit
def heapq_numba(a):
    """
    優先度付きキュー。最大の数値を取り出す。
    ----------
    HEAP_MAX_SIZE: queの最大サイス
    _heapque: キュー用に必要なサイズを確保した ndarray
    _heapsz: 現時点のqueのサイズ
    _heapempty: queが空の場合にpopが返す値
    """
    HEAP_MAX_SIZE = 100
    _heapque = np.empty(HEAP_MAX_SIZE, dtype=np.int64)
    _heapsz = 0
    _heapempty = -1

    def heappush(val):
        nonlocal _heapsz
        i = _heapsz
        _heapsz += 1
        while i > 0:
            parent = (i - 1) // 2
            if _heapque[parent] < val:
                _heapque[i] = _heapque[parent]
                i = parent
            else:
                _heapque[i] = val
                break
        else:
            _heapque[0] = val

    def heappop():
        nonlocal _heapsz
        if _heapsz == 0:
            return _heapempty
        ret = _heapque[0]
        _heapsz -= 1
        val = _heapque[_heapsz]
        i = 0
        while True:
            child = 2 * i + 1
            if child >= _heapsz:
                _heapque[i] = val
                break
            if _heapque[child] > _heapque[child + 1]:
                if _heapque[child] > val:
                    _heapque[i] = _heapque[child]
                    i = child
                else:
                    _heapque[i] = val
                    break
            else:
                if _heapque[child + 1] > val:
                    _heapque[i] = _heapque[child + 1]
                    i = child + 1
                else:
                    _heapque[i] = val
                    break
        return ret

    res = np.empty_like(a)
    for c in a:
        heappush(c)
    for i in range(a.shape[0]):
        res[i] = heappop()
    return res

この優先度付きキューのサンプルコードでは、整数型しか使えませんが、タプルやリストを使いたい場合も多いと思います。現状では、コードを少し修正する必要があります。こういうことを考えると、C++ のテンプレートは便利ということになります。

Python で競プロをすることについて

今回の AtCoder の言語アップデートで、Numba の キャッシュ及び AOT の提案に関しては、Twitter での黒木玄氏と高橋直大氏との間で行われた以下の Julia の実行時間に関する議論が参考になりました。コンパイル時に「JITコンパイル結果を残しておく」という方法で、Julia や Ruby 等の JIT コンパイラーでも使えるので、AtCoderの方で今回のジャッジシステムのアップデートで対応していただけたと思っています。

https://twitter.com/genkuroki/status/1141270586154311683
https://twitter.com/chokudai/status/1141244427333066752

Python で競プロをすることについては、いろいろな意見がありますが、結局は、それぞれのオンラインコンテストのルールと問題次第です。Python は、使いやすくて優秀なライブラリーが揃った言語です。特に GPU が使える問題と環境であれば、圧倒的な強みを発揮します。一方で、ライブラリーが使えないと単に使いやすいだけの遅い言語になります。

そのため、Codeforces では、全く使いものになりません。Python で提出しようとすると「Almost always, if you send a solution on PyPy, it works much faster」と表示されます。問題も時間制限が厳しいので PyPy でも厳しいと思います。一方で、競プロとは違うという人もいますが Kaggle のコンペでは、8割以上の人が Python を使っています。また、プログラミングコンテスト「ICFP-PC」は、計算環境制限なしの72時間耐久コンテストですが、そこでも主要なプログラミング言語の一つになっています。

AtCoder の場合は、Codeforces ほどはガチ競プロではないので、Python でも十分戦えます。昨年、maspyさんが橙(2400+)になり、Python/PyPyで、これまでの rated contest で、難易度2800以下のものを全てACさせて実証しました([AtCoder] 橙(2400+)になりました参照)。Maspyさんの頑張りのおかげで Python に不足していたアルゴリズム等の知見が増えたというのも大きいです。

こう考えていくと、Python は C++ と相補的な関係にあることがよくわかると思います。Python と C++ は、両方とも使いこなせれるのが理想なのですが、そのようになるだけの時間がないのが普通なので、自分にあった方から学習すればいいと思います。両者は特徴がハッキリしているので、どちらを選択したらいいかはわかりやすいと思います。

スクリプト言語の中では、機能及び情報の豊富さの両面から判断すると、Python は競プロに最も適した言語というのは間違いないでしょう。学習の容易さも考慮すれば、プログラムの初心者には、コンパイラー言語も含めて最も適した言語といえると思います。PyPy3 は、初心者にとって、学習の容易さと処理速度のバランスが取れたいい選択肢になります。その後、データサイエンス・機械学習等で Numpy が必要な人であれば Numba を、それ以外の人であれば Cython を使うことで上位まで行くことが可能です。また、アルゴリズムの世界で食っていこうと思うのであれば C++ を勉強すればすればいいです。Python で使っているライブラリで最近開発された TensorFlow, PyTorch, Open3D, Dlib, pyrealsense, apache arrow 等は少なくともコア部分は C++ で開発されています。でも、実際にそれらのライブラリを使う場合は C++ ではなくて、python を使うので Python を学習したことが無駄にはなりません。

他のスクリプト言語についても悲観する必要はありません。最近は、スクリプト言語も十分に高速になってきています。

JavaScript の場合は、今回の言語アップデートで機能的には大きく改善されて、PyPy よりもずっと速い言語になっています。また、altJS の Dart で AOT が使えるようになっています。ただし、UI 関係を中心に使われてきたため、アルゴリズム関係のライブラリはあまり整備されていません。また、競プロだけに限っていうとユーザーがあまりいないので、競プロに関する情報や知見が少ないのが欠点です。maspy さんのような人が現れると、大きく伸びるのでないでしょうか。

Ruby で競プロをしていた人が書いたブログ「スクリプト言語などで競プロをすることについて」によると、Rubyは遅くて「通せない問題があることを受け入れる」必要があり、青色まで到達するのが怪しい言語だそうです。しかし、Ruby にも Oracle が開発した TruffleRuby という高速な実装があったりするので、高速化は容易なはずです。

参考

ベンチマークに使用したコード

Numba (AOT)

Main.py
import numpy as np
import sys


if sys.argv[-1] == 'ONLINE_JUDGE':
    from numba.pycc import CC;
    cc = CC('udf')
    
    @cc.export('calc', 'void(i4[:,:])')
    def calc(a):
        # 横方向
        for j in range(a.shape[0] - 1):
            prev = 0
            start = 0
            for i in range(a.shape[1]):
                if a[j, i] > 0:
                    if prev == 0:
                        start = i
                        prev = 1
                else:
                    if prev == 1:
                        num = i - start
                        for k in range(start, i):
                            a[j, k] = num
                        prev = 0
        # 縦方向
        for i in range(a.shape[1] - 1):
            prev = 0
            start = 0
            for j in range(a.shape[0]):
                if a[j, i] > 0:
                    if prev == 0:
                        start = j
                        prev = 1
                else:
                    if prev == 1:
                        num = j - start - 1
                        for k in range(start, j):
                            a[k, i] += num
                        prev = 0
    cc.compile()


def main():
    cin = sys.stdin.buffer.read().split(maxsplit=2)
    H = int(cin[0])
    W = int(cin[1])
    a = np.empty((H + 1, W + 1), dtype='i4')
    a[:H, :] = (np.frombuffer(cin[2][:H * (W + 1)], dtype='B') == ord('.')).reshape((H, W + 1))
    a[H] = np.zeros(W+1, dtype='i4')
    calc(a)
    print(a.max())


if __name__ == "__main__":
    from udf import calc
    main()

Numba(JIT)

キャッシュ有りの場合は、@njit@njit('void(i4[:,:])', cache=True)に修正

Main.py
import numpy as np
import sys
from numba import njit

@njit
def calc(a):
    # 横方向
    for j in range(a.shape[0] - 1):
        prev = 0
        start = 0
        for i in range(a.shape[1]):
            if a[j, i] > 0:
                if prev == 0:
                    start = i
                    prev = 1
            else:
                if prev == 1:
                    num = i - start
                    for k in range(start, i):
                        a[j, k] = num
                    prev = 0
    # 縦方向
    for i in range(a.shape[1] - 1):
        prev = 0
        start = 0
        for j in range(a.shape[0]):
            if a[j, i] > 0:
                if prev == 0:
                    start = j
                    prev = 1
            else:
                if prev == 1:
                    num = j - start - 1
                    for k in range(start, j):
                        a[k, i] += num
                    prev = 0


cin = sys.stdin.buffer
H, W = map(int, cin.readline().split())
a = np.empty((H + 1, W + 1), dtype='i4')
a[:H, :] = (np.frombuffer(cin.read(H * (W + 1)), dtype='B') == ord('.')).reshape((H, W + 1))
a[H] = np.zeros(W + 1, dtype='i4')

calc(a)
print(a.max())

Cython のコード

Main.pyx
import numpy as np
import sys

cdef void calc(int[:,:] a):
    cdef int i, j, k, prev, start, num;
    # 横方向
    for j in range(a.shape[0] - 1):
        prev = 0
        start = 0
        for i in range(a.shape[1]):
            if a[j, i] > 0:
                if prev == 0:
                    start = i
                    prev = 1
            else:
                if prev == 1:
                    num = i - start
                    for k in range(start, i):
                        a[j, k] = num
                    prev = 0
    # 縦方向
    for i in range(a.shape[1] - 1):
        prev = 0
        start = 0
        for j in range(a.shape[0]):
            if a[j, i] > 0:
                if prev == 0:
                    start = j
                    prev = 1
            else:
                if prev == 1:
                    num = j - start - 1
                    for k in range(start, j):
                        a[k, i] += num
                    prev = 0


def main():
    cin = sys.stdin.buffer
    H, W = map(int, cin.readline().split())
    a = np.empty((H + 1, W + 1), dtype='i4')
    a[:H, :] = (np.frombuffer(cin.read(H * (W + 1)), dtype='B') == ord('.')).reshape((H, W + 1))
    a[H] = np.zeros(W+1, dtype='i4')

    calc(a)
    print(a.max())


if __name__ == "__main__":
    main()

コンパイルコマンド及び実行コマンド

cython -3 --embed Main.pyx
gcc -O2 -I /usr/include/python3.7m Main.c -lpython3.7m
./a.out

C コンパイラーのことはよくわからないので、下記の資料を参考にしてください。

https://github.com/cython/cython/wiki/EmbeddingCython
https://github.com/cython/cython/blob/master/Demos/embed/Makefile
https://github.com/cython/cython/wiki/FAQ#how-can-i-make-a-standalone-binary-from-a-python-program-using-cython

Pythran のコード

main.py
#pythran export main2()
import numpy as np

def calc(a):
    for j in range(a.shape[0] - 1):
        prev = 0
        start = 0
        for i in range(a.shape[1]):
            if a[j, i] > 0:
                if prev == 0:
                    start = i
                    prev = 1
            else:
                if prev == 1:
                    num = i - start
                    for k in range(start, i):
                        a[j, k] = num
                    prev = 0
    # 縦方向
    for i in range(a.shape[1] - 1):
        prev = 0
        start = 0
        for j in range(a.shape[0]):
            if a[j, i] > 0:
                if prev == 0:
                    start = j
                    prev = 1
            else:
                if prev == 1:
                    num = j - start
                    for k in range(start, j):
                        a[k, i] += num - 1
                    prev = 0


def main2():
    stdin = open('/dev/stdin')
    H, W = map(int, stdin.readline().split())
    a = np.zeros((H+1, W+1), dtype=np.int64)
    s = stdin.read(H *(W + 1))
    n = 0
    for i in range(H):
        for j in range(W + 1):
            if ord(s[n]) == 46:
                a[i, j] = 1
            n += 1
    calc(a)
    print(a.max())

コンパイルコマンド及び実行コマンド

printf '#include \"main.hpp\"\nusing namespace __pythran_main ; int main() { main2()(); return 0 ; }' > main.cpp;
pythran -e main.py -o main.hpp &&`pythran-config --compiler --cflags` -std=c++17 -O3 -DUSE_XSIMD -march=native main.cpp -o main
./main

詳しくは、公式マニュアルの Command Line Interface のページをみてください。
自分でテストする場合は、実行可能ファイルにする必要はないので、以下のコマンドの方が簡単です。

pythran main.py
python -C 'import main;main.main2()'

テストデータの作成コード

mysample.py
def abc129_d(H, W):
    with open('sample.txt', 'w') as f:
        f.write(f'{H} {W}\n')
        a = np.random.choice([b'#', b'.'], (H, W + 1))
        a[:, -1] = b'\n'
        a.tofile(f, sep='')

abc129_d(2000, 2000)

その他のコードは、AtCoder の提出一覧から拾えるので、ここでは省略します。

88
80
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
88
80

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?