4
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Google DriveからColabに大量のデータを引っ張ってくるとき遅すぎる問題を並列処理で解決した話

Last updated at Posted at 2023-03-19

タイトルのほか特に言うことはないのですが、参考までにどういう経緯でどういうことをしたのかを書いていきます。

背景

現在気象予測に関して機械学習を取り入れてみよう、みたいな研究をしているのですが、それなりにお得に使えるColab Pro+(月額5000円くらい)とGoogle One ベーシック(Driveの保存容量が100GBになる, 月額250円)を組み合わせて使っていました。

image.png

気象情報はこんな感じでYYmmddHH.grib2というファイル名を/lab_data/data下にバーッとならべています。Google DriveをColabにマウントした後、これらを一つ一つpygribというライブラリで読み込んでいくと、、

image.png

遅え…1

というわけでこれをなんとか解決しようという話です。

問題の詳細、問題設定

まず前提として

  • 一ファイルあたりの大きさはほぼ1MB
  • 今回読み込みたいファイルの数は736個2

です。なので、全ての通信すべきデータ量としては750MBほどです。あんまり厳密ではないですが、大目に見てください。

例えば1MB/sくらいの速度でダウンロードしたなら12分半ほどかかります。この場合、初めてcolabのコンピューティングインスタンスを立ち上げて一度Driveから読み込むときは必ず12分半ほどかかるというわけです。そのインスタンスを使うならば二度目以降の読み込みはキャッシュが効いて高速になりますが、それでも毎回立ち上げ時に待ってはいられないですね、、

方針

2つ方針を考えました:

  1. 並行処理によってDriveから読み込む
  2. ファイルたちを適切な粒度に分割しzipにまとめ、それをダウンロードして解凍

今回は1つめの方針を採用しました。なぜなら、2つめの方針が面倒くさそうだったからです。
適切な粒度というのがどれほどなのかわかりませんし、もし追加でファイルが追加されたらその都度圧縮しなければいけないですし、、
2つめがどれほど効果があるのかは要検証です。どなたか検証してください

実装

以下ではMulti processingとSingle processingで取得時間を比較し、更に本当に同じデータが取得できているか確かめるためにcalc関数で取得したデータの総和を取っています3

importなどの用意
!pip install pygrib
!pip install wget
import pygrib
from datetime import datetime, timedelta
import multiprocessing
import wget
import os, sys
import subprocess
import tqdm
import time
import numpy as np
from google.colab import drive
drive.mount('/content/drive')
データを正しく取得できているか判定するためのcalc関数
def calc(data_ll):
    total = 0
    for data_l in data_ll:
        for data in data_l:
            rho, u_vert, u_hori = data
            total += int(np.sum(rho)) + int(np.sum(u_vert)) + int(np.sum(u_hori))

    return total

Multi Processing

multiprocessingパッケージを使います。スレッドのやり取りにはmultiprocessing.Queueを使いました。
workerは8つ作り、それぞれで92個のファイルをダウンロードしてきましょう。一般的にColabインスタンスのCPUのプロセス数は4らしいですが、ここでネックになるのはCPU I/OではなくNetwork I/Oなのでworker数はそれより多くても問題ないと思います(多分)。

それではまずはworkerのクラスを作りましょう。ここでは無秩序にqueueにデータを放り込んでいますが、もし順番などが重要になってくるならdict形式でどのworkerから帰ってきたデータなのかをマーキングしてQueueに入れればいいと思います。

class Worker:
    def __init__(self, num: int, dts: list[datetime], que: multiprocessing.Queue):
        self.num = num
        self.dts = dts
        self.que = que

    def work(self):
        ret = []
        t = time.time()
        print(f"start process {self.num}\n")

        for dt in self.dts:
            with pygrib.open(f'/content/drive/MyDrive/lab_data/data/{dt:%Y%m%d%H}.grib2') as grib:
                rho = np.array(grib.select()[0].values, dtype=np.float32)
                u_hori = np.array(grib.select()[1].values, dtype=np.float32)
                u_vert = np.array(grib.select()[2].values, dtype=np.float32)
                ret.append((rho.copy(), u_hori.copy(), u_vert.copy()))
        
        print(f"download complete in process {self.num}, {time.time() - t}s\n")

        self.que.put(ret)
        print(f"end process {self.num}, {time.time() - t:.3f}s\n")

あとは実際にmultiprocessing.Process()にこのworkerを渡してやって、メインスレッドでqueueからデータを取り出せばよいです。

def multi_data_load():
    core = 8
    que = multiprocessing.Queue()

    dts_list = [[datetime(2011 + i, 7, 1, 0) + timedelta(days = j) for j in range(92)] for i in range(core)]
    workers = [Worker(i, dts_list[i], que) for i in range(core)]
    res = []

    processes = [multiprocessing.Process(target=worker.work, args=()) for worker in workers]

    print("start processes")

    for process in processes:
        process.start()

    for _ in range(len(processes)):
        res.append(que.get())

    for process in processes:
        process.join()
    
    print("end processes")

    return calc(res)

Single Processing

上のworkを単純にメインスレッドに書くだけです。

def single_data_load():
    core = 8
    dts_list = [[datetime(2011 + i, 7, 1, 0) + timedelta(days = j) for j in range(92)] for i in range(core)]
    res = []

    for dts in dts_list:
        ret = []
        for dt in dts:
            with pygrib.open(f'/content/drive/MyDrive/lab_data/data/{dt:%Y%m%d%H}.grib2') as grib:
                rho = np.array(grib.select()[0].values, dtype=np.float32)
                u_hori = np.array(grib.select()[1].values, dtype=np.float32)
                u_vert = np.array(grib.select()[2].values, dtype=np.float32)
                ret.append((rho.copy(), u_hori.copy(), u_vert.copy()))
        res.append(ret)

    return calc(res)

計測結果

ともに計測をする前にインスタンスを破棄してキャッシュが効いていない状態で実験しました。
まずはMulti processingの結果からです

%time multi_data_load()

start processes
start process 0
start process 1
start process 2
start process 4
start process 3
start process 5
start process 6
start process 7
download complete in process 7, 108.31141257286072s
end process 7, 108.340s
download complete in process 3, 116.9851770401001s
end process 3, 116.992s
download complete in process 1, 118.3642942905426s
download complete in process 5, 118.3714051246643s
end process 5, 118.399s
download complete in process 4, 118.48550724983215s
end process 4, 118.502s
download complete in process 6, 118.90653777122498s
end process 6, 118.943s
end process 1, 118.373s
download complete in process 0, 119.97895193099976s
end process 0, 119.992s
download complete in process 2, 121.08988094329834s
end process 2, 121.111s
end processes
CPU times: user 1.51 s, sys: 4.65 s, total: 6.16 s
Wall time: 2min 5s
18063118171031

次にSingle Processingの結果です。

%time single_data_load()

CPU times: user 25.9 s, sys: 7.24 s, total: 33.2 s
Wall time: 8min 14s
18063118171031

まとめると、

  • Multi Processing
    • 計125秒で終了
    • データの合計は18063118171031
  • Single Processing
    • 計494秒で終了
    • データの合計は18063118171031

という感じですね。細かいことを言うとMemory I/OやCPU I/Oなどもあるので単純比較はできないのですが、だいたい4倍くらい早くなりました。

まとめ、考察

あんまりこういうアプローチをしている記事がなかったので書いてみました。
Single Processingだとあんまり速度が出ない理由はプロセス単位でNetwork I/Oの速度制限が掛けられているから、とか、Google Driveの1 Socketあたりの速度制限があるから、とかあるんでしょうか…?読み込みが二回目以降になるとキャッシュが効いて速度がめちゃくちゃ上がることから鑑みるにpygrib.openの速度の問題とかではなくやはりネットワークの速度の問題な気がします。

  1. 最初だけ早いですが、これは途中で読み込みをやめてもう一度読み込んだところキャッシュが効いたので早くなったからです。本筋とは関係ないですが、、

  2. 一年当たりのデータ数が7月1日から9月30日までの92個、それが8年分です。

  3. 細かいことをいうと浮動小数点は足し算の順番によって結果が変わってきたりするので、データの和をintに切り捨てて更にその和を取っている感じです。

4
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?