タイトルのほか特に言うことはないのですが、参考までにどういう経緯でどういうことをしたのかを書いていきます。
背景
現在気象予測に関して機械学習を取り入れてみよう、みたいな研究をしているのですが、それなりにお得に使えるColab Pro+(月額5000円くらい)とGoogle One ベーシック(Driveの保存容量が100GBになる, 月額250円)を組み合わせて使っていました。
気象情報はこんな感じでYYmmddHH.grib2というファイル名を/lab_data/data下にバーッとならべています。Google DriveをColabにマウントした後、これらを一つ一つpygribというライブラリで読み込んでいくと、、
遅え…1
というわけでこれをなんとか解決しようという話です。
問題の詳細、問題設定
まず前提として
- 一ファイルあたりの大きさはほぼ1MB
- 今回読み込みたいファイルの数は736個2
です。なので、全ての通信すべきデータ量としては750MBほどです。あんまり厳密ではないですが、大目に見てください。
例えば1MB/sくらいの速度でダウンロードしたなら12分半ほどかかります。この場合、初めてcolabのコンピューティングインスタンスを立ち上げて一度Driveから読み込むときは必ず12分半ほどかかるというわけです。そのインスタンスを使うならば二度目以降の読み込みはキャッシュが効いて高速になりますが、それでも毎回立ち上げ時に待ってはいられないですね、、
方針
2つ方針を考えました:
- 並行処理によってDriveから読み込む
- ファイルたちを適切な粒度に分割しzipにまとめ、それをダウンロードして解凍
今回は1つめの方針を採用しました。なぜなら、2つめの方針が面倒くさそうだったからです。
適切な粒度というのがどれほどなのかわかりませんし、もし追加でファイルが追加されたらその都度圧縮しなければいけないですし、、
2つめがどれほど効果があるのかは要検証です。どなたか検証してください
実装
以下ではMulti processingとSingle processingで取得時間を比較し、更に本当に同じデータが取得できているか確かめるためにcalc関数で取得したデータの総和を取っています3。
importなどの用意
!pip install pygrib
!pip install wget
import pygrib
from datetime import datetime, timedelta
import multiprocessing
import wget
import os, sys
import subprocess
import tqdm
import time
import numpy as np
from google.colab import drive
drive.mount('/content/drive')
データを正しく取得できているか判定するためのcalc関数
def calc(data_ll):
total = 0
for data_l in data_ll:
for data in data_l:
rho, u_vert, u_hori = data
total += int(np.sum(rho)) + int(np.sum(u_vert)) + int(np.sum(u_hori))
return total
Multi Processing
multiprocessingパッケージを使います。スレッドのやり取りにはmultiprocessing.Queue
を使いました。
workerは8つ作り、それぞれで92個のファイルをダウンロードしてきましょう。一般的にColabインスタンスのCPUのプロセス数は4らしいですが、ここでネックになるのはCPU I/OではなくNetwork I/Oなのでworker数はそれより多くても問題ないと思います(多分)。
それではまずはworkerのクラスを作りましょう。ここでは無秩序にqueueにデータを放り込んでいますが、もし順番などが重要になってくるならdict形式でどのworkerから帰ってきたデータなのかをマーキングしてQueueに入れればいいと思います。
class Worker:
def __init__(self, num: int, dts: list[datetime], que: multiprocessing.Queue):
self.num = num
self.dts = dts
self.que = que
def work(self):
ret = []
t = time.time()
print(f"start process {self.num}\n")
for dt in self.dts:
with pygrib.open(f'/content/drive/MyDrive/lab_data/data/{dt:%Y%m%d%H}.grib2') as grib:
rho = np.array(grib.select()[0].values, dtype=np.float32)
u_hori = np.array(grib.select()[1].values, dtype=np.float32)
u_vert = np.array(grib.select()[2].values, dtype=np.float32)
ret.append((rho.copy(), u_hori.copy(), u_vert.copy()))
print(f"download complete in process {self.num}, {time.time() - t}s\n")
self.que.put(ret)
print(f"end process {self.num}, {time.time() - t:.3f}s\n")
あとは実際にmultiprocessing.Process()
にこのworkerを渡してやって、メインスレッドでqueueからデータを取り出せばよいです。
def multi_data_load():
core = 8
que = multiprocessing.Queue()
dts_list = [[datetime(2011 + i, 7, 1, 0) + timedelta(days = j) for j in range(92)] for i in range(core)]
workers = [Worker(i, dts_list[i], que) for i in range(core)]
res = []
processes = [multiprocessing.Process(target=worker.work, args=()) for worker in workers]
print("start processes")
for process in processes:
process.start()
for _ in range(len(processes)):
res.append(que.get())
for process in processes:
process.join()
print("end processes")
return calc(res)
Single Processing
上のworkを単純にメインスレッドに書くだけです。
def single_data_load():
core = 8
dts_list = [[datetime(2011 + i, 7, 1, 0) + timedelta(days = j) for j in range(92)] for i in range(core)]
res = []
for dts in dts_list:
ret = []
for dt in dts:
with pygrib.open(f'/content/drive/MyDrive/lab_data/data/{dt:%Y%m%d%H}.grib2') as grib:
rho = np.array(grib.select()[0].values, dtype=np.float32)
u_hori = np.array(grib.select()[1].values, dtype=np.float32)
u_vert = np.array(grib.select()[2].values, dtype=np.float32)
ret.append((rho.copy(), u_hori.copy(), u_vert.copy()))
res.append(ret)
return calc(res)
計測結果
ともに計測をする前にインスタンスを破棄してキャッシュが効いていない状態で実験しました。
まずはMulti processingの結果からです
%time multi_data_load()
start processes
start process 0
start process 1
start process 2
start process 4
start process 3
start process 5
start process 6
start process 7
download complete in process 7, 108.31141257286072s
end process 7, 108.340s
download complete in process 3, 116.9851770401001s
end process 3, 116.992s
download complete in process 1, 118.3642942905426s
download complete in process 5, 118.3714051246643s
end process 5, 118.399s
download complete in process 4, 118.48550724983215s
end process 4, 118.502s
download complete in process 6, 118.90653777122498s
end process 6, 118.943s
end process 1, 118.373s
download complete in process 0, 119.97895193099976s
end process 0, 119.992s
download complete in process 2, 121.08988094329834s
end process 2, 121.111s
end processes
CPU times: user 1.51 s, sys: 4.65 s, total: 6.16 s
Wall time: 2min 5s
18063118171031
次にSingle Processingの結果です。
%time single_data_load()
CPU times: user 25.9 s, sys: 7.24 s, total: 33.2 s
Wall time: 8min 14s
18063118171031
まとめると、
- Multi Processing
- 計125秒で終了
- データの合計は18063118171031
- Single Processing
- 計494秒で終了
- データの合計は18063118171031
という感じですね。細かいことを言うとMemory I/OやCPU I/Oなどもあるので単純比較はできないのですが、だいたい4倍くらい早くなりました。
まとめ、考察
あんまりこういうアプローチをしている記事がなかったので書いてみました。
Single Processingだとあんまり速度が出ない理由はプロセス単位でNetwork I/Oの速度制限が掛けられているから、とか、Google Driveの1 Socketあたりの速度制限があるから、とかあるんでしょうか…?読み込みが二回目以降になるとキャッシュが効いて速度がめちゃくちゃ上がることから鑑みるにpygrib.openの速度の問題とかではなくやはりネットワークの速度の問題な気がします。