PythonでJoblibというライブラリを使ってitertools.groupby(以下groupby)でグループ化したデータを並列実行処理したときに、なぜか上手く動かないことがありました。
解決後に判ったことをふまえて、原因を説明したいと思います。
起きたこと(再現)
JoblibというライブラリのParallelにgroupbyを使ったジェネレータを渡して、並列処理を実装しました。
from joblib import Parallel, delayed
from itertools import groupby
# タグごとのメンバーを表示する
def view_members(tag, members):
print(f"Tag:{tag}")
for member in members:
print(f"Name:{member['name']} Birthday:{member['birthday']}")
# メンバーのデータ
props = [
{"name": "Alice", "birthday": "2000/11/22", "Tag": "A"},
{"name": "Bob", "birthday": "1999/2/15", "Tag": "B"},
{"name": "Carol", "birthday": "1967/1/25", "Tag": "B"},
{"name": "Dave", "birthday": "1997/4/21", "Tag": "A"}
]
# groupbyを使うための前処理
props.sort(key=lambda x: x["Tag"])
# グループ化したメンバーのデータを関数に渡して並列実行
Parallel(n_jobs=4)(delayed(view_members)(
tag,
members
) for tag, members in groupby(props, key=lambda x: x["Tag"]))
そしたら結果は...
Tag:A
Tag:B
タグは表示されているのに、メンバーが表示されないです。
なぜうまく動かないのか、しばらく悩みました...
この問題を理解するために
この問題を理解するために、まずは以下の2つのPythonの仕様を理解しなければいけません。
- listの仕様
- groupbyの仕様
1. listの仕様
なぜいきなりlistなのかと疑問かもしれません。
簡単に伝えておくと、私が使っていたjoblib.Parallelでは、引数で渡したイテラブルな値を内部でlist()
を使って変換しています。ですので、groupbyでグループ化した値をリストに変換するとどうなるかが重要です。
まずはリストのコンストラクタlist()
を説明します。
と言っても、listはcpythonとしてC言語で実装されていますので、さらっと読むだけにしておきます。
2796行目からlistの処理が実装されています。
引用元 | https://github.com/python/cpython/blob/main/Objects/listobject.c
/*[clinic input]
list.__init__
iterable: object(c_default="NULL") = ()
/
Built-in mutable sequence.
If no argument is given, the constructor creates a new empty list.
The argument must be an iterable if specified.
[clinic start generated code]*/
static int
list___init___impl(PyListObject *self, PyObject *iterable)
/*[clinic end generated code: output=0f3c21379d01de48 input=b3f3fe7206af8f6b]*/
{
/* Verify list invariants established by PyType_GenericAlloc() */
assert(0 <= Py_SIZE(self));
assert(Py_SIZE(self) <= self->allocated || self->allocated == -1);
assert(self->ob_item != NULL ||
self->allocated == 0 || self->allocated == -1);
/* Empty previous contents */
if (self->ob_item != NULL) {
(void)_list_clear(self);
}
if (iterable != NULL) {
if (_PyObject_HasLen(iterable)) {
Py_ssize_t iter_len = PyObject_Size(iterable);
if (iter_len == -1) {
if (!PyErr_ExceptionMatches(PyExc_TypeError)) {
return -1;
}
PyErr_Clear();
}
if (iter_len > 0 && self->ob_item == NULL
&& list_preallocate_exact(self, iter_len)) {
return -1;
}
}
PyObject *rv = list_extend(self, iterable);
if (rv == NULL)
return -1;
Py_DECREF(rv);
}
return 0;
}
細かいところはさておき、list(iterable)
を実行したらどうなるかを見てみると、2836行目でlist_extendを呼んでいますね。
つまり、内部的には空のリストに対してlist.extend(iterable)
するのとほぼ同じだと考えられます。
list_extendの実装を見てみます。
868行目からです。
list_extend(PyListObject *self, PyObject *iterable)
/*[clinic end generated code: output=630fb3bca0c8e789 input=9ec5ba3a81be3a4d]*/
{
PyObject *it; /* iter(v) */
Py_ssize_t m; /* size of self */
Py_ssize_t n; /* guess for size of iterable */
Py_ssize_t mn; /* m + n */
Py_ssize_t i;
PyObject *(*iternext)(PyObject *);
/* Special cases:
1) lists and tuples which can use PySequence_Fast ops
2) extending self to self requires making a copy first
*/
if (PyList_CheckExact(iterable) || PyTuple_CheckExact(iterable) ||
(PyObject *)self == iterable) {
PyObject **src, **dest;
iterable = PySequence_Fast(iterable, "argument must be iterable");
if (!iterable)
return NULL;
n = PySequence_Fast_GET_SIZE(iterable);
if (n == 0) {
/* short circuit when iterable is empty */
Py_DECREF(iterable);
Py_RETURN_NONE;
}
m = Py_SIZE(self);
/* It should not be possible to allocate a list large enough to cause
an overflow on any relevant platform */
assert(m < PY_SSIZE_T_MAX - n);
if (list_resize(self, m + n) < 0) {
Py_DECREF(iterable);
return NULL;
}
/* note that we may still have self == iterable here for the
* situation a.extend(a), but the following code works
* in that case too. Just make sure to resize self
* before calling PySequence_Fast_ITEMS.
*/
/* populate the end of self with iterable's items */
src = PySequence_Fast_ITEMS(iterable);
dest = self->ob_item + m;
for (i = 0; i < n; i++) {
PyObject *o = src[i];
Py_INCREF(o);
dest[i] = o;
}
Py_DECREF(iterable);
Py_RETURN_NONE;
}
it = PyObject_GetIter(iterable);
if (it == NULL)
return NULL;
iternext = *Py_TYPE(it)->tp_iternext;
/* Guess a result list size. */
n = PyObject_LengthHint(iterable, 8);
if (n < 0) {
Py_DECREF(it);
return NULL;
}
m = Py_SIZE(self);
if (m > PY_SSIZE_T_MAX - n) {
/* m + n overflowed; on the chance that n lied, and there really
* is enough room, ignore it. If n was telling the truth, we'll
* eventually run out of memory during the loop.
*/
}
else {
mn = m + n;
/* Make room. */
if (list_resize(self, mn) < 0)
goto error;
/* Make the list sane again. */
Py_SET_SIZE(self, m);
}
/* Run iterator to exhaustion. */
for (;;) {
PyObject *item = iternext(it);
if (item == NULL) {
if (PyErr_Occurred()) {
if (PyErr_ExceptionMatches(PyExc_StopIteration))
PyErr_Clear();
else
goto error;
}
break;
}
if (Py_SIZE(self) < self->allocated) {
/* steals ref */
PyList_SET_ITEM(self, Py_SIZE(self), item);
Py_SET_SIZE(self, Py_SIZE(self) + 1);
}
else {
int status = app1(self, item);
Py_DECREF(item); /* append creates a new ref */
if (status < 0)
goto error;
}
}
/* Cut back result list if initial guess was too large. */
if (Py_SIZE(self) < self->allocated) {
if (list_resize(self, Py_SIZE(self)) < 0)
goto error;
}
Py_DECREF(it);
Py_RETURN_NONE;
error:
Py_DECREF(it);
return NULL;
}
947行目のfor文をみると、コメントにも書いてある通り、list_extendは受け取ったiterable
(イテラブルなオブジェクト)を最後まで取り出して変数に詰めなおす処理となっています。
結論、list(iterable)
の戻り値は、渡したiterable
をfor文で取り出して詰めなおした値ですね。
2. groupbyの仕様
groupbyについて、公式ドキュメントを引用しながら説明します。
同じキーをもつような要素からなる iterable 中のグループに対して、キーとグループを返すようなイテレータを作成します。key は各要素に対するキー値を計算する関数です。キーを指定しない場合や None にした場合、key 関数のデフォルトは恒等関数になり要素をそのまま返します。通常、iterable は同じキー関数でソート済みである必要があります。
この説明では難しいので、簡単な例を示します。
(分かった方はこの先の説明は不要かもしれません。)
from itertools import groupby
# グループ化するデータ
props = [
{"name": "Alice", "birthday": "2000/11/22", "Tag": "A"},
{"name": "Bob", "birthday": "1999/2/15", "Tag": "B"},
{"name": "Carol", "birthday": "1967/1/25", "Tag": "B"},
{"name": "Dave", "birthday": "1997/4/21", "Tag": "A"}
]
# Tagでグループ化するためにソートする
## {"name": "Alice", "birthday": "2000/11/22", "Tag": "A"},
## {"name": "Dave", "birthday": "1997/4/21", "Tag": "A"},
## {"name": "Bob", "birthday": "1999/2/15", "Tag": "B"},
## {"name": "Carol", "birthday": "1967/1/25", "Tag": "B"}
props.sort(key=lambda x: x["Tag"])
# groupbyの第一引数にグループ化したい値、
# 第二引数にグループ化するときの値を取得する関数を渡す
for tag, members in groupby(props, key=lambda x: x["Tag"]):
print(tag, members)
for member in members:
print(member)
A <itertools._grouper object at 0x000001BF1D5D5C40>
{'name': 'Alice', 'birthday': '2000/11/22', 'Tag': 'A'}
{'name': 'Dave', 'birthday': '1997/4/21', 'Tag': 'A'}
B <itertools._grouper object at 0x000001BF1D5D5C10>
{'name': 'Bob', 'birthday': '1999/2/15', 'Tag': 'B'}
{'name': 'Carol', 'birthday': '1967/1/25', 'Tag': 'B'}
こんな感じで、groupbyはグループ化したい値とグループ化するときのキーを取得する関数を渡すことで、キーの値とitertools._grouperという型のジェネレータ(イテレータの一種)を返すものです。
今回の問題は、このitertools._grouperオブジェクトからメンバーの情報を取り出せないという現象でした。
では、itertools._grouperはどのように生成されているのでしょうか?
公式ドキュメントが紹介している実装イメージを見てみましょう。
01 class groupby:
02 # [k for k, g in groupby('AAAABBBCCDAABBB')] --> A B C D A B
03 # [list(g) for k, g in groupby('AAAABBBCCD')] --> AAAA BBB CC D
04 def __init__(self, iterable, key=None):
05 if key is None:
06 key = lambda x: x
07 self.keyfunc = key
08 self.it = iter(iterable)
09 self.tgtkey = self.currkey = self.currvalue = object()
10 def __iter__(self):
11 return self
12 def __next__(self):
13 self.id = object()
14 while self.currkey == self.tgtkey:
15 self.currvalue = next(self.it) # Exit on StopIteration
16 self.currkey = self.keyfunc(self.currvalue)
17 self.tgtkey = self.currkey
18 return (self.currkey, self._grouper(self.tgtkey, self.id))
19 def _grouper(self, tgtkey, id):
20 while self.id is id and self.currkey == tgtkey:
21 yield self.currvalue
22 try:
23 self.currvalue = next(self.it)
24 except StopIteration:
25 return
26 self.currkey = self.keyfunc(self.currvalue)
groupbyから値を取り出すときに呼び出されるのは__next__
です。
14行目のwhile文でnext(self.it)
で次の値を取り出していますね。このループが終わるタイミングは、キーの値が切り替わったタイミングです。
そうすると、self.tgtkey
が次のキーの値を持っているself.currkey
で更新されて、キーの値とself._grouper
(=_grouper)で取得したイテレータを返します。
これが一連の流れです。
では、_grouperの詳細を見てみましょう。
引数でself.tgtkey
とself.id
を受け取っていますね。ぱっと見だとやっていることは__next__
と似ています。
大きく違うところは、21行目のyield self.currvalue
で呼び出し毎に現在の値を返すようなジェネレータになっていることと、while文の条件にself.id
が使われていることですね。
self.id
は__next__
が一番最初に呼ばれたときにobject()
で初期化されます。生成したオブジェクトをidのように扱っているということです。
つまり、_grouperで生成されたジェネレータは、値が取り出せなくなるパターンが3つあります。
- 最後のキー以外に対する値を取り出している場合、値を取り出していき26行目で
self.currkey
が次のキーに切り替わってself.currkey == tgtkey
がFalseになる。 - 最後のキーに対する値を取り出している場合、23行目でStopIteration例外が起こるまで呼ばれてreturnする。
- 全て値を取り出す前に、親であるgroupby(の
__next__
)が呼ばれてself.id
が初期化されself.id is id
がFalseになる。
原因
今回の問題が起きた原因をまとめると以下です。
-
groupbyのイテレータが
Parallel
の中でlist(iterable)
のように渡された。listの仕様で説明したように、一旦全てgroupbyのイテレータが消費されて、キーとitertools._grouperオブジェクトがlist型の変数に詰め替えられた。 -
groupbyの仕様のパターン3の通り、groupbyのイテレータが次に進んでしまった後では内部の
self.id
が変わってしまっているため、itertools._grouperオブジェクトから値を取り出すことができなくなった。
このように、Parallelから並列実行する関数にitertools._grouperオブジェクトであるmembers
が渡された時にはgroupbyのイテレータが全て消費されており、一つも値が取り出せなかったのです。
これはitertools._grouperオブジェクトの特徴であるため、Parallel以外を使った場合も起こる可能性があります。
解決策
イテレータを渡した先の内部でどのように使われるかがわからない場合、itertools._grouperオブジェクトをリストに変換するのがシンプルな解決策です。
Parallel(n_jobs=4)(delayed(view_members)(
tag,
+ list(members)
) for tag, members in groupby(props, key=lambda x: x["Tag"]))
どうしてもitertools._grouperオブジェクトをそのまま使いたいという場合は、copy.deepcopyでコピーすることでも回避できます。(メモリ節約のためにgroupbyの中でイテレータを共有するような設計だと思うので、ちょっともったいない気もします。)
+ from copy import deepcopy
Parallel(n_jobs=4)(delayed(view_members)(
tag,
+ deepcopy(members)
) for tag, members in groupby(props, key=lambda x: x["Tag"]))
おわりに
ちょっとしたつまづきでしたが、ちゃんと理解しようと深堀りすると意外と大きな学びがありました。
まさかC言語の実装まで読むことにはなるとは思いませんでしたが、結果として今までなんとなくしか理解していなかったイテラブルなオブジェクトの管理やイテレータについても深く理解できました。