Rayとは
RayはPythonにおける分散並列処理を高速かつシンプルに書けるフレームワークで, 既存のコードを並列化することも容易な設計となっています.
Rayを使うことでmultiprocessingなどに比べ簡単にプロセスレベルの並列処理を記述することができます.
本記事はRayチュートリアルの内容をもとにしており,
コードはPython 3.8.2, Ray 0.8.4での動作を確認しています.
インストール
ターミナルでpipなどからインストールできます.
$ pip install ray
使い方
基本的な用途としては覚える文法はray.init
ray.remote
ray.get
の3つのみで, この記事では加えてray.wait
ray.put
も紹介します.
Rayによる並列化の基本
実行に3秒かかる関数 func
が二度呼び出され全体の実行に6秒かかる以下のコードについて, func
の実行を並列化することを考えましょう.
import time
def func(x):
time.sleep(3)
return x
begin_time = time.time() # 開始時刻を記録
res1, res2 = func(1), func(2) # funcを2度呼ぶ
print(res1, res2) # 出力: 1 2
end_time = time.time() # 終了時刻を記録
print(end_time - begin_time) # 6秒ぐらい
Rayを使う場合には 必ず最初に ray.init
で使用するリソース数の指定などを行いRayのプロセスを起動する必要があります.
import ray
# ray.init() のように明示的に指定しなかった場合自動的にリソース数が決定されます
ray.init(num_cpus=4)
# 時間計測をより正確にする都合上Rayの起動を少し待つ
time.sleep(1)
ある関数を並列で実行させたい場合, その関数をRayの扱える remote関数 にする必要があります.
といってもやり方は簡単で, その関数に@ray.remote
とデコレーターをつけるだけです.
remote関数は(関数名).remote(引数)
として呼び出すとRayの並列ワーカーに送られて実行されます.
.remote(引数)
は終了を待たずに Object ID というものをreturnします.
@ray.remote
def func(x):
time.sleep(3)
return x
begin_time = time.time()
res1, res2 = func.remote(1), func.remote(2)
print(res1) # 出力例: ObjectID(45b9....)
結果を取得したい場合には, remote関数から返ってきたObject IDをray.get
に渡してあげればよいです.
ray.get
はObject IDに対応する結果がすべて取得できるまでブロッキングします.
print(ray.get(res1), ray.get(res2)) # 出力: 1 2
# ray.getはリストを受けとることもできる
print(ray.get([res1, res2])) # 出力: [1, 2]
end_time = time.time()
print(end_time - begin_time) # 3秒ぐらい
以上のコードを1つのスクリプトとして実行すると3秒程度しかかかっておらず, func
の実行が並列化されていることがわかります.
基本はこれだけです.
依存関係のある並列化
Rayはremote関数間に依存関係があっても, Object IDをそのまま受け渡すことで処理が可能です.
受け渡されたObject IDは実際に実行される際には通常のPythonオブジェクトに復元されて実行されます.
以下の例では, vec
内の4つの各要素に対してfunc1
とfunc2
を順に適用しています. 1要素の処理には2秒かかります.
※これ以降の例では時間計測のためのコードを省略しています
@ray.remote
def func1(x):
time.sleep(1)
return x * 2
@ray.remote
def func2(x):
time.sleep(1)
return x + 1
vec = [1, 2, 3, 4]
results = []
for x in vec:
res = func1.remote(x) # resにはObjectIDが入っている
res = func2.remote(res) # ObjectIDをそのまま次のremote関数に渡す
results.append(res)
# resultsはObjectIDのリスト
print(ray.get(results)) # 出力: [3, 5, 7, 9]
Rayは依存関係を解析し, 依存先のない func1
を先に並列実行し,その後 func1
の処理の終わった要素について func2
を並列実行します.
逐次では8秒かかるこの処理は並列化により2秒程度で実行されます.
また, Rayはネストされた呼び出しにも対応しており, func2
を次のように書き換えても問題なく動作します.
ネスト呼び出しの条件は, 呼び出したい関数が事前に定義されていることだけです.
@ray.remote
def func2(x):
x = func1.remote(x) # ObjectIDが返される
time.sleep(1)
return ray.get(x) + 1 # ObjectIDと直接足し算は出来ないため, ray.getしてから計算する
print(ray.get([func2.remote(x) for x in vec])) # 出力: [3, 5, 7, 9]
私の環境での実測値は2秒より少し遅くなりましたが, 8秒よりは速く並列に実行できています.
Actor
remote関数は実行されたあとそのままreturnしてしまい状態を持つことができません.
状態をもつような処理を, Rayではクラスを@ray.remote
で修飾することにより実現します.
@ray.remote
で修飾されたクラスを Actor と呼びます.
例えば, 次のような一度のインクリメントにつき1秒かかるカウンターを考えましょう.
Actorのインスタンスを作る時も, 関数呼び出しのときと同様 .remote()
を付けます.
@ray.remote
class Counter:
def __init__(self, init_val, sleep=True):
# カウンターをinit_valで初期化
self.count = init_val
self.sleep = sleep
def increment(self):
if self.sleep:
time.sleep(1)
self.count += 1
return self.count
# 初期値0と100のカウンターを作る
counter1, counter2 = Counter.remote(0), Counter.remote(100)
それぞれのカウンターを3回ずつインクリメントしながら, 各段階での値をresultsに記録していきましょう.
results = []
for _ in range(3):
results.append(counter1.increment.remote())
results.append(counter2.increment.remote())
print(ray.get(results)) # 出力: [1, 101, 2, 102, 3, 103]
合計6回インクリメントがされていますが, カウンターごとに並列化されているので3秒しかかからずに値を取得することができます.
また, Actorの同一のインスタンスのメソッドを並列に呼び出したいときには, Actorのインスタンスを引数にとるremote関数を定義すればよいです.
例えば次のように, 1秒おきに increment
を呼び出す incrementer
という関数を0.5秒ずらして実行させてみましょう.
ここではincrement
自体が一瞬で終わるようなCounter
を用意していることに注意してください.
@ray.remote
def incrementer(counter, id, times):
# 1秒おきにtimes回インクリメントを行う
for _ in range(times):
cnt = counter.increment.remote()
print(f'id= {id} : count = {ray.get(cnt)}')
time.sleep(1)
counter = Counter.remote(0, sleep=False) # 1回のインクリメントが一瞬で終わるカウンター
incrementer.remote(counter, id=1, times=5)
time.sleep(0.5) # 開始を0.5秒ずらす
inc = incrementer.remote(counter, id=2, times=5)
ray.wait([inc]) # 次に説明する, 終了を待つ関数
実行すると, 次のように incrementer
がcounter
の値を0.5秒おきに交互に更新している様子がわかります.
(0.0秒後) id = 1 : count = 1
(0.5秒後) id = 2 : count = 2
(1.0秒後) id = 1 : count = 3
(1.5秒後) id = 2 : count = 4
(2.0秒後) ......
ray.wait
並列実行されているObject IDのリストを ray.get
に渡すと, そのすべての実行が終了するまで値を取得できません.
ray.wait
を使うと, 並列実行された関数のうち指定した数が終了するまで待機し, その時点で終了したIDとそうでないIDを別々にreturnしてくれます.
@ray.remote
def sleep(x):
# x秒休んでxを返す関数
time.sleep(x)
return x
ids = [sleep.remote(3), sleep.remote(5), sleep.remote(2)]
finished_ids, running_ids = ray.wait(ids, num_returns=2, timeout=None)
print(ray.get(finished_ids)) # 出力(3秒経過時点): [3,2]
print(ray.get(running_ids)) # 出力(5秒経過時点): [5]
ray.put
実は, remote
関数に渡されたオブジェクトは, そのつど暗黙裏にシリアライズされてRayの共有メモリ上にコピーされます.
そのため, 巨大なオブジェクトを remote
の引数に複数回渡してしまうと余計にコピーするための時間がかかるほか, 共有メモリ上の領域を無駄に消費してしまいます.
このような場合には, ray.put
を用いて事前に一度だけ明示的にコピーを行うことによりこの無駄を回避することができます.
ray.put
は remote
と同様Object IDを返し, これをremote関数に渡してあげればよいです.
一度コピーされたオブジェクトは共有されているので, 並列実行するワーカーはどれもこのオブジェクトを参照することができます.
@ray.remote
def func4(obj, idx):
time.sleep(1)
return idx
# big_object はサイズの大きなobjectだとする
big_object = None
big_obj_id = ray.put(big_object)
# func.remote()が4回呼ばれるが, いま渡しているのはObjectIDのため再度big_objectのコピーは発生しない
results = [func4.remote(big_obj_id, i) for i in range(4)]
print(ray.get(results)) # 出力: [0, 1, 2, 3]
なお, Rayのray.get
によるデシリアライズは pickle.load
に比べて非常に高速であるようです.
おわりに
公式ドキュメントにはより詳細な使い方が載っています.
特にExamplesには分散環境でのパラメータサーバーや強化学習などの具体的な用例が載っていて参考になるでしょう.
またRayを基盤とした高レベルなフレームワークも用意されており, 強化学習向けのRLlibやハイパーパラメータチューニング向けのTuneなどがあります.
是非Rayを使って快適な並列処理ライフを手に入れましょう.