Python
python3
flatten

深さの全く違うネストのイテレータを flattenする方法

目標

今回はflattenする対象として

target1 = [(1, 2, 3, [4, 5, 6]), 7, 8, [9, (10, range(11, 14), [14])]]
target2 = (range(1, 4), [4, 5, (6, 7, [8, 9, (10, 11, 12, [13, 14])])])

の2つを作った関数を用いて

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]

にするのを目標にします。

標準ライブラリの限界

pythonの標準ライブラリには itertools.chan というモジュールがあります
こちらは

itertools.chain(*iterables)
先頭の iterable の全要素を返し、次に2番目の iterable の全要素を返し、と全 iterable の要素を返すイテレータを作成します。連続したシーケンスを一つのシーケンスとして扱う場合に使用します。およそ次と等価です:

def chain(*iterables):
    # chain('ABC', 'DEF') --> A B C D E F
    for it in iterables:
        for element in it:
            yield element

classmethod chain.from_iterable(iterable)
chain() のためのもう一つのコンストラクタです。遅延評価される iterable 引数一つから連鎖した入力を受け取ります。この関数は、以下のコードとほぼ等価です:

def from_iterable(iterables):
    # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
    for it in iterables:
        for element in it:
            yield element

というように、一見flattenしてくれそうですが、実際のところ

>>> from itertools import chain
>>> target1 = [(1, 2, 3, [4, 5, 6]), 7, 8, [9, (10, range(11, 14), [14])]]
>>> target2 = (range(1, 4), [4, 5, (6, 7, [8, 9, (10, 11, 12, [13, 14])])])
>>> list(chain.from_iterable(target1))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: 'int' object is not iterable

というふうに、2重以上にネストされたイテレータしか受け取ってくれません。

>>> list(chain.from_iterable([[1, 2, 3], [4, 5, 6]]))
[1, 2, 3, 4, 5, 6]

ちょうどこんな具合です。

しかし2重以上にネストされていればすべてflattenできるのかといわれたらそうでもなく。

>>> list(chain.from_iterable([[1, 2, 3], [[4, 5, 6]]]))
[1, 2, 3, [4, 5, 6]]

このように2重目のイテレータしかflattenしてくれません。

じゃあどうするのか

自分で関数を作ります

イテレータをflattenするには、yield from を使うと良いでしょう。

その前に、 yield, yield from って?

yield, yield from というのは、ジェネレータのreturn に当たるものです。
yield と return の違いは、そこで関数・ジェネレータを終わるか終わらないかです。

>>> def func():
...     yield 1
...     yield 2
...     yield 3
... 
>>> s = func()
>>> next(s)
1
>>> next(s)
2
>>> next(s)
3
>>> next(s)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
StopIteration

こちらがジェネレータになります。
ジェネレータはnext(ジェネレータ)で値を取り出しますが、複数回取り出せていますよね。
これは、yieldが3つあるため、3回値を返されているからです。

関数でやった場合、最初の1しか返ってきません。

>>> def func():
...     return 1
...     return 2
...     return 3
... 
>>> func()
1

yield from というのは、イテレータを対象にすることで、まるで複数個のyieldがあるかのように振る舞うようなステートメントのことです。

>>> def func():
...     yield from [1, 2, 3]
... 
>>> s = func()
>>> next(s)
1
>>> next(s)
2
>>> next(s)
3
>>> next(s)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
StopIteration

つまり、

( yield 1
  yield 2
  yield 3) 
== 
( yield from [1, 2, 3])

ということです。
これに更に再帰を組み合わせることで、実装できますよね。

ということで、やってみましょう。

実装

イテレータを引数にとるジェネレータを作成しましょう。

def flatten(iterables):
    pass

このイテレータの要素のうち、イテレータであるものはyield from そうでないものは yield を使いましょう。

def flatten(iterables):
    for element in iterables:
        if "element が iterable":
            if len(element) < 2:
                "elseに回す"
            yield from flatten(element)
        else:
            yield element

このifの部分ですが、面倒なので(isinstanceとかtypeを使えばできますが、指定するのが面倒),try-catchで良いと思います。

def flatten(iterables):
    for element in iterables:
        try:
            print(type(element))
            if len(element) < 2 and type(element) == str:
                raise(TypeError)
            yield from flatten(element)
        except TypeError:
            yield element

こんな感じで。lenの部分は、str等による無限再帰対策です。
もし文字列を1文字ずつぶつ切りにするのが嫌なら

def flatten(iterables, mercy=(str,)):
    for element in iterables:
        try:
            if type(element) in mercy:
                raise(TypeError)
            yield from flatten(element, mercy)
        except TypeError:
            yield element

といった感じで除外すればいいと思います。

結果

>>> def flatten(iterables, mercy=(str,)):
...     for element in iterables:
...         try:
...             if isinstance(element, mercy):
...                 raise(TypeError)
...             yield from flatten(element)
...         except TypeError:
...             yield element
... 
>>> list(flatten(target1))
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
>>> list(flatten(target2))
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]

ちゃんとできましたね。

rangeは開きたくない!というなら、

>>> list(flatten(target1, mercy=(str, range)))
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, range(11, 14), 14]
>>> list(flatten(target2, mercy=(str, range)))
[range(1, 4), 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]

こういった形で避けられるでしょう。

みなさんがflattenで悩んだとき、このページを見に来ていただければ助けになると思います。

それではみなさん、たのしいPython Lifeを!