Edited at

深さの全く違うネストのイテレータを flattenする方法


目標

今回はflattenする対象として

target1 = [(1, 2, 3, [4, 5, 6]), 7, 8, [9, (10, range(11, 14), [14])]]

target2 = (range(1, 4), [4, 5, (6, 7, [8, 9, (10, 11, 12, [13, 14])])])

の2つを作った関数を用いて

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]

にするのを目標にします。


標準ライブラリの限界

pythonの標準ライブラリには itertools.chan というモジュールがあります

こちらは


itertools.chain(*iterables)

先頭の iterable の全要素を返し、次に2番目の iterable の全要素を返し、と全 iterable の要素を返すイテレータを作成します。連続したシーケンスを一つのシーケンスとして扱う場合に使用します。およそ次と等価です:

def chain(*iterables):

# chain('ABC', 'DEF') --> A B C D E F
for it in iterables:
for element in it:
yield element

classmethod chain.from_iterable(iterable)

chain() のためのもう一つのコンストラクタです。遅延評価される iterable 引数一つから連鎖した入力を受け取ります。この関数は、以下のコードとほぼ等価です:

def from_iterable(iterables):

# chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
for it in iterables:
for element in it:
yield element

というように、一見flattenしてくれそうですが、実際のところ

>>> from itertools import chain

>>> target1 = [(1, 2, 3, [4, 5, 6]), 7, 8, [9, (10, range(11, 14), [14])]]
>>> target2 = (range(1, 4), [4, 5, (6, 7, [8, 9, (10, 11, 12, [13, 14])])])
>>> list(chain.from_iterable(target1))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: 'int' object is not iterable

というふうに、2重以上にネストされたイテレータしか受け取ってくれません。

>>> list(chain.from_iterable([[1, 2, 3], [4, 5, 6]]))

[1, 2, 3, 4, 5, 6]

ちょうどこんな具合です。

しかし2重以上にネストされていればすべてflattenできるのかといわれたらそうでもなく。

>>> list(chain.from_iterable([[1, 2, 3], [[4, 5, 6]]]))

[1, 2, 3, [4, 5, 6]]

このように2重目のイテレータしかflattenしてくれません。


じゃあどうするのか

自分で関数を作ります

イテレータをflattenするには、yield from を使うと良いでしょう。


その前に、 yield, yield from って?

yield, yield from というのは、ジェネレータのreturn に当たるものです。

yield と return の違いは、そこで関数・ジェネレータを終わるか終わらないかです。

>>> def func():

... yield 1
... yield 2
... yield 3
...
>>> s = func()
>>> next(s)
1
>>> next(s)
2
>>> next(s)
3
>>> next(s)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
StopIteration

こちらがジェネレータになります。

ジェネレータはnext(ジェネレータ)で値を取り出しますが、複数回取り出せていますよね。

これは、yieldが3つあるため、3回値を返されているからです。

関数でやった場合、最初の1しか返ってきません。

>>> def func():

... return 1
... return 2
... return 3
...
>>> func()
1

yield from というのは、イテレータを対象にすることで、まるで複数個のyieldがあるかのように振る舞うようなステートメントのことです。

>>> def func():

... yield from [1, 2, 3]
...
>>> s = func()
>>> next(s)
1
>>> next(s)
2
>>> next(s)
3
>>> next(s)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
StopIteration

つまり、

( yield 1

yield 2
yield 3)
==
( yield from [1, 2, 3])

ということです。

これに更に再帰を組み合わせることで、実装できますよね。

ということで、やってみましょう。


実装

イテレータを引数にとるジェネレータを作成しましょう。

def flatten(iterables):

pass

このイテレータの要素のうち、イテレータであるものはyield from そうでないものは yield を使いましょう。

def flatten(iterables):

for element in iterables:
if "element が iterable":
if len(element) < 2:
"elseに回す"
yield from flatten(element)
else:
yield element

このifの部分ですが、面倒なので(isinstanceとかtypeを使えばできますが、指定するのが面倒),try-catchで良いと思います。

def flatten(iterables):

for element in iterables:
try:
print(type(element))
if len(element) < 2 and type(element) == str:
raise(TypeError)
yield from flatten(element)
except TypeError:
yield element

こんな感じで。lenの部分は、str等による無限再帰対策です。

もし文字列を1文字ずつぶつ切りにするのが嫌なら

def flatten(iterables, mercy=(str,)):

for element in iterables:
try:
if type(element) in mercy:
raise(TypeError)
yield from flatten(element, mercy)
except TypeError:
yield element

といった感じで除外すればいいと思います。


結果

>>> def flatten(iterables, mercy=(str,)):

... for element in iterables:
... try:
... if isinstance(element, mercy):
... raise(TypeError)
... yield from flatten(element)
... except TypeError:
... yield element
...
>>> list(flatten(target1))
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
>>> list(flatten(target2))
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]

ちゃんとできましたね。

rangeは開きたくない!というなら、

>>> list(flatten(target1, mercy=(str, range)))

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, range(11, 14), 14]
>>> list(flatten(target2, mercy=(str, range)))
[range(1, 4), 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]

こういった形で避けられるでしょう。

みなさんがflattenで悩んだとき、このページを見に来ていただければ助けになると思います。

それではみなさん、たのしいPython Lifeを!