Python で「for 文で取り出すアイテムをオンデマンドに生成したい / 何周目かによって切り出しサイズを変えたい」といったときにはイテレータの自作が、「(with 文で) 後処理の実行を保証させたい」といったときにはコンテキストマネージャの自作が便利です。
これらの型を自作するには、守るべき約束 (プロトコル) を守ったクラスを実装すればよいです。双方の型ともプロトコルを自分で直接実装しなくて済むシンタクスシュガー (便利な書き方) も提供されていますが、それもプロトコルとその役割を踏まえて使う方が勝手がわかりやすいと思います。1
ただ双方の型ともそんなに頻繁に実装するものでもなく、自作したいときにはプロトコルを忘れがちなので、備忘メモを書きます。
イテレータ
イテレータは、for 文で取り出すアイテムの生成 / 読み込み / 切り出しを都度行いたいようなときや、何エポック目のループなのかを保持したいときなどに便利です。
イテレータプロトコル (約束) ・ その役割
イテレータ型は以下を守る必要があります。
-
(1) 自分自身を返すメソッド
__iter__()を持ちます。- このメソッドは
for文にイテレータを渡したとき一度実行されます。
- このメソッドは
-
(2) 次のアイテムを返すメソッド
__next__()を持ちます。もし次のアイテムがなければこのメソッドは StopIteration を送出します。一度 StopIteration を送出したら、以降の呼び出しでも例外を送出し続けます。- このメソッドは
for文でアイテムを取り出す度に実行されます。
- このメソッドは
プロトコルを直接実装する以外の方法
プロトコルを直接実装しなくても、以下でもイテレータを作成できます。
- そのクラスのメソッド
__iter__()をジェネレータ関数 (yield式を含む関数) にすれば、そのオブジェクトはfor文に渡すとイテレータオブジェクトを返します。2- ジェネレータ関数はイテレータオブジェクトを返し、そのイテレータオブジェクトは次の要素が取り出される度にジェネレータ関数内の次の
yield式まで実行します。
- ジェネレータ関数はイテレータオブジェクトを返し、そのイテレータオブジェクトは次の要素が取り出される度にジェネレータ関数内の次の
- 単にジェネレータ関数を実装すれば、これはイテレータオブジェクトを返します。
イテレータの例
以下の A., B., C. いずれもフィボナッチ数列を取り出せます。どれで書くかは用途や好みによりますが、状態を保持したい / 設定値を持たせたい / 継承で拡張したい場合はクラスにするのが向きます。他方、単純な反復処理なら yield 式がすっきりするかもしれません。
class Fibonacci:
def __iter__(self):
print('===== __iter__() was called. =====')
self.a, self.b = 1, 1
self.counter = 0
return self
def __next__(self):
if self.counter == 10:
raise StopIteration
self.counter += 1
x = 1
if self.counter > 2:
x = self.a + self.b
self.a, self.b = self.b, x
return x
if __name__ == '__main__':
fi = Fibonacci()
# for 文に渡せば反復処理できます
for x in fi:
print(x)
# また for 文に渡せばまた反復処理できます
for x in fi:
print(x)
# 自前で順次取り出すこともできます (あまりやらないと思いますが)
iter(fi) # __iter__() をよびます
print(next(fi)) # __next__() をよびます
print(next(fi)) # __next__() をよびます
print(next(fi)) # __next__() をよびます
class Fibonacci:
def __iter__(self):
print('===== __iter__() was called. =====')
a, b = 1, 1
for counter in range(1, 11):
x = 1
if counter > 2:
x = a + b
a, b = b, x
yield x
if __name__ == '__main__':
fi = Fibonacci()
for x in fi:
print(x)
for x in fi:
print(x)
def fibonacci():
a, b = 1, 1
for counter in range(1, 11):
x = 1
if counter > 2:
x = a + b
a, b = b, x
yield x
if __name__ == '__main__':
for x in fibonacci():
print(x)
for x in fibonacci():
print(x)
コンテキストマネージャ
コンテキストマネージャは with 文に渡して利用し、ブロック (with 節の中身) を出るとき (普通に出たときでも例外で追い出されたときでも) に必ず指定の後処理を実行させることができます。後処理の実行を保証させたいときに便利です。
コンテキストマネージャプロトコル (約束) ・ その役割
-
(1) メソッド
__enter__()を持ちます。- このメソッドはブロックに入るとき呼び出されます。
- このメソッドの返り値は
as節で取り出すことができます。
-
(2) ブロック内で送出された例外の型と値とトレースバックを受け取り (例外が送出されなかったらすべて None)、ブール値を返すメソッド
__exit__(exc_type, exc_val, exc_tb)を持ちます。-
このメソッドはブロックから出るとき (普通に出たときでも例外で追い出されたときでも) 必ず呼び出されます。
with文 のドキュメントの「これは次と等価です:」がわかりやすいです。- ブロック内で送出された例外はこのメソッドの返り値が True ならばそこで握りつぶされ、False ならば再送出されます。
-
このメソッドはブロックから出るとき (普通に出たときでも例外で追い出されたときでも) 必ず呼び出されます。
プロトコルを直接実装する以外の方法
- 標準モジュール contextlib のデコレータ
@contextmanagerで、1 回だけyield式をもつジェネレータ関数をデコレートすれば、これはコンテキストマネージャオブジェクトを返します。with文実行時はまずジェネレータ関数がyield式まで実行され (返却値はas節で取り出せます)、その後にブロックが実行され、ブロックを出るとジェネレータ関数の続きが実行されます。
コンテキストマネージャ型の例
以下の A., B. いずれもブロックの所要時間を計測します。いずれも例外を握りつぶさないので、ブロック内の # raise ValueError のコメントアウトを外した場合は print('end') には到達しません。例外を握りつぶして到達させたい場合は、A. なら __exit__() の返り値を True にし、B. なら except 節を記述して例外を握りつぶします。
import time
class MeasureTime:
def __init__(self, info):
self.info = info
def __enter__(self):
self.start = time.perf_counter()
return self.info
def __exit__(self, exc_type, exc, tb):
elapsed = time.perf_counter() - self.start
mins, secs = divmod(elapsed, 60)
self.info['elapsed'] = f'{int(mins)} min {int(secs)} sec'
print(self.info)
return False # 例外は握りつぶさない
if __name__ == '__main__':
info = {'a': 123}
with MeasureTime(info) as info:
print(info)
# raise ValueError
time.sleep(3)
print('end')
from contextlib import contextmanager
import time
@contextmanager
def measure_time(info):
start = time.perf_counter()
try:
yield info
# except:
# pass
finally:
elapsed = time.perf_counter() - start
mins, secs = divmod(elapsed, 60)
info['elapsed'] = f'{int(mins)} min {int(secs)} sec'
print(info)
if __name__ == '__main__':
info = {'a': 123}
with measure_time(info) as info:
print(info)
# raise ValueError
time.sleep(3)
print('end')