LoginSignup
2
6

More than 5 years have passed since last update.

深さの全く違うネストのイテレータを flattenする方法

Last updated at Posted at 2018-09-14

目標

今回はflattenする対象として

target1 = [(1, 2, 3, [4, 5, 6]), 7, 8, [9, (10, range(11, 14), [14])]]
target2 = (range(1, 4), [4, 5, (6, 7, [8, 9, (10, 11, 12, [13, 14])])])

の2つを作った関数を用いて

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]

にするのを目標にします。

標準ライブラリの限界

pythonの標準ライブラリには itertools.chan というモジュールがあります
こちらは

itertools.chain(*iterables)
先頭の iterable の全要素を返し、次に2番目の iterable の全要素を返し、と全 iterable の要素を返すイテレータを作成します。連続したシーケンスを一つのシーケンスとして扱う場合に使用します。およそ次と等価です:

def chain(*iterables):
    # chain('ABC', 'DEF') --> A B C D E F
    for it in iterables:
        for element in it:
            yield element

classmethod chain.from_iterable(iterable)
chain() のためのもう一つのコンストラクタです。遅延評価される iterable 引数一つから連鎖した入力を受け取ります。この関数は、以下のコードとほぼ等価です:

def from_iterable(iterables):
    # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
    for it in iterables:
        for element in it:
            yield element

というように、一見flattenしてくれそうですが、実際のところ

>>> from itertools import chain
>>> target1 = [(1, 2, 3, [4, 5, 6]), 7, 8, [9, (10, range(11, 14), [14])]]
>>> target2 = (range(1, 4), [4, 5, (6, 7, [8, 9, (10, 11, 12, [13, 14])])])
>>> list(chain.from_iterable(target1))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: 'int' object is not iterable

というふうに、2重以上にネストされたイテレータしか受け取ってくれません。

>>> list(chain.from_iterable([[1, 2, 3], [4, 5, 6]]))
[1, 2, 3, 4, 5, 6]

ちょうどこんな具合です。

しかし2重以上にネストされていればすべてflattenできるのかといわれたらそうでもなく。

>>> list(chain.from_iterable([[1, 2, 3], [[4, 5, 6]]]))
[1, 2, 3, [4, 5, 6]]

このように2重目のイテレータしかflattenしてくれません。

じゃあどうするのか

自分で関数を作ります

イテレータをflattenするには、yield from を使うと良いでしょう。

その前に、 yield, yield from って?

yield, yield from というのは、ジェネレータのreturn に当たるものです。
yield と return の違いは、そこで関数・ジェネレータを終わるか終わらないかです。

>>> def func():
...     yield 1
...     yield 2
...     yield 3
... 
>>> s = func()
>>> next(s)
1
>>> next(s)
2
>>> next(s)
3
>>> next(s)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
StopIteration

こちらがジェネレータになります。
ジェネレータはnext(ジェネレータ)で値を取り出しますが、複数回取り出せていますよね。
これは、yieldが3つあるため、3回値を返されているからです。

関数でやった場合、最初の1しか返ってきません。

>>> def func():
...     return 1
...     return 2
...     return 3
... 
>>> func()
1

yield from というのは、イテレータを対象にすることで、まるで複数個のyieldがあるかのように振る舞うようなステートメントのことです。

>>> def func():
...     yield from [1, 2, 3]
... 
>>> s = func()
>>> next(s)
1
>>> next(s)
2
>>> next(s)
3
>>> next(s)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
StopIteration

つまり、

( yield 1
  yield 2
  yield 3) 
== 
( yield from [1, 2, 3])

ということです。
これに更に再帰を組み合わせることで、実装できますよね。

ということで、やってみましょう。

実装

イテレータを引数にとるジェネレータを作成しましょう。

def flatten(iterables):
    pass

このイテレータの要素のうち、イテレータであるものはyield from そうでないものは yield を使いましょう。

def flatten(iterables):
    for element in iterables:
        if "element が iterable":
            if len(element) < 2:
                "elseに回す"
            yield from flatten(element)
        else:
            yield element

このifの部分ですが、面倒なので(isinstanceとかtypeを使えばできますが、指定するのが面倒),try-catchで良いと思います。

def flatten(iterables):
    for element in iterables:
        try:
            print(type(element))
            if len(element) < 2 and type(element) == str:
                raise(TypeError)
            yield from flatten(element)
        except TypeError:
            yield element

こんな感じで。lenの部分は、str等による無限再帰対策です。
もし文字列を1文字ずつぶつ切りにするのが嫌なら

def flatten(iterables, mercy=(str,)):
    for element in iterables:
        try:
            if type(element) in mercy:
                raise(TypeError)
            yield from flatten(element, mercy)
        except TypeError:
            yield element

といった感じで除外すればいいと思います。

結果

>>> def flatten(iterables, mercy=(str,)):
...     for element in iterables:
...         try:
...             if isinstance(element, mercy):
...                 raise(TypeError)
...             yield from flatten(element)
...         except TypeError:
...             yield element
... 
>>> list(flatten(target1))
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
>>> list(flatten(target2))
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]

ちゃんとできましたね。

rangeは開きたくない!というなら、

>>> list(flatten(target1, mercy=(str, range)))
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, range(11, 14), 14]
>>> list(flatten(target2, mercy=(str, range)))
[range(1, 4), 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]

こういった形で避けられるでしょう。

みなさんがflattenで悩んだとき、このページを見に来ていただければ助けになると思います。

それではみなさん、たのしいPython Lifeを!

2
6
4

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
6