概要
Jupyter上で sqlalchemy
(SQLMagicの%sql)を使って
データベースからレコードを取り出し並列計算をしようとしたときに
エラーが出て対処に苦労したので備忘録としてまとめます。
環境
postgreSQL(version9.5.7)
Anaconda(version4.0.0)
jupyter(version1.0.0)
sqlalchemy(version1.1.9)
joblib(version0.11)
背景
Jupyter上でPostgreSQLのデータベースに
sqlalchemy
のSQLMagicを用いてアクセスし、
取り出したデータを用いて joblib
で並列計算を行おうとしていました。
records = %sql SELECT * FROM TABLE_A
[(1, 'Alice'),
(2, 'Bob'),
(3, 'Cayce')]
other_records = %sql SELECT * FROM TABLE_B
[(1, 100),
(2, 200),
(3, 300),
(4, 400),
(5, 500)]
def parallel_func(record, other_records):
...
#省略
from joblib import Parallel, delayed
r = Parallel(n_jobs=-1, verbose=5) (delayed(parallel_func) (record, other_records) for record in records)
other_recordsは固定のデータなのでそのまま引数に与えて、
変化させるrecordsのレコードをfor文で一行ずつとりだして
parallel_funcを行う予定でした。
しかし、実行すると以下のエラーが出てしまいました
---------------------------------------------------------------------------
PicklingError Traceback (most recent call last)
<ipython-input-46-a4803b957f2d> in <module>()
----> 1 r = Parallel(n_jobs=6, verbose=3) ( delayed(update_fragment_ids) (record, parents) for record in records )
/home/yamasakih/.pyenv/versions/anaconda-4.0.0/envs/rdkit/lib/python3.6/site-packages/joblib/parallel.py in __call__(self, iterable)
787 # consumption.
788 self._iterating = False
--> 789 self.retrieve()
790 # Make sure that we get a last message telling us we are done
791 elapsed_time = time.time() - self._start_time
/home/yamasakih/.pyenv/versions/anaconda-4.0.0/envs/rdkit/lib/python3.6/site-packages/joblib/parallel.py in retrieve(self)
697 try:
698 if getattr(self._backend, 'supports_timeout', False):
--> 699 self._output.extend(job.get(timeout=self.timeout))
700 else:
701 self._output.extend(job.get())
/home/yamasakih/.pyenv/versions/anaconda-4.0.0/envs/rdkit/lib/python3.6/multiprocessing/pool.py in get(self, timeout)
606 return self._value
607 else:
--> 608 raise self._value
609
610 def _set(self, i, obj):
/home/yamasakih/.pyenv/versions/anaconda-4.0.0/envs/rdkit/lib/python3.6/multiprocessing/pool.py in _handle_tasks(taskqueue, put, outqueue, pool, cache)
383 break
384 try:
--> 385 put(task)
386 except Exception as e:
387 job, ind = task[:2]
/home/yamasakih/.pyenv/versions/anaconda-4.0.0/envs/rdkit/lib/python3.6/site-packages/joblib/pool.py in send(obj)
369 def send(obj):
370 buffer = BytesIO()
--> 371 CustomizablePickler(buffer, self._reducers).dump(obj)
372 self._writer.send_bytes(buffer.getvalue())
373 self._send = send
PicklingError: Can't pickle <built-in function input>: it's not the same object as builtins.input
リストやタプルはpickleできなかったかな?と
ドキュメントを確認しましたが、特に問題はありませんでした。
原因
色々試行錯誤した後にどうしてもわからず
Pythonが得意な知人のお兄さんに質問したところ、
type(records)
type(other_records)
を一度試した方が良いと言われたので実行してみました。
>>> type(records)
>>> sql.run.ResultSet
>>> type(other_records)
>>> sql.run.ResultSet
と表示されました。
リストだと勘違いしていたレコードのデータは
実際にはsql.run.ResultSet
というクラスのオブジェクトでした。
よくよく考えると DataFrame()
と言う名前の
pandas
のデータフレームにする list
にはないメソッドありますしね。
出力がリスト風だったので完全に勘違いしていました。
対処
sql.run.ResultSet
のオブジェクトが list
に変換できるか試してみました。
>>> other_records = list(other_records)
>>> type(other_records)
>>> list
無事 list
に変換されました!そして中身も変わってしまうことがないのを確認しました。
そこで、改めて以下のように並列計算を実行したところうまく計算が進みました!
from joblib import Parallel, delayed
r = Parallel(n_jobs=-1, verbose=5) (delayed(parallel_func) (record, list(other_records)) for record in records)
なお、 record
は list
する必要が無いようです。
sql.run.ResultSet
オブジェクトからレコードをforで取り出すと
sqlalchemy.engine.result.RowProxy
オブジェクトになるからか、
そもそもfor文で回す部分はpickleの処理をしないから?
ではないかということです。
(詳しい方教えていただけたら助かります)
最後に
以上で sqlalchemy
で取り出したレコードを引数にして joblib
で
並列計算を行う時につまずいたところの説明を終わります。
戒め案件として、Pythonの得意なお兄さんに感謝しつつまとめておきました。
皆様の楽しい並列計算ライフの助けになれば幸いです。
ここまでお読みいただきありがとうございました。