LoginSignup
0
0

More than 1 year has passed since last update.

pythonでジェネレータを繰り返し利用する

Last updated at Posted at 2021-07-01

ジェネレータを繰り返し利用する

python 3.7

結論

自作のクラスを作成する

class MyGen:
    def __init__(self.num):
        self.num = num

    def __iter__(self):
        for i in range(self.num):
            yield i

gen = MyGen(10)
for i in range(3):
    for j in gen:
        print(i, j)

目次

  • ジェネレータを繰り返し利用したい理由
  • ジェネレータの作成
  • 実際に画像を都度読み込むジェネレータ
  • tqdmで使う

ジェネレータを繰り返し利用したい理由

画像認識用のAIでの学習でIter毎に画像を読み込む処理をする必要がありジェネレータを利用した。
AIの学習ではエポックと呼ばれる全てのデータを1回利用することを1単位と取るモノを複数利用しなければならない。
故に繰り返し使えるジェネレータが必要とされた

ジェネレータの作成

関数におけるジェネレータは繰り返し使うことができない(僕が知っている限りでは)。
次のプログラムは最初の繰り返しのみ結果を返す。


def mygenerator(num:int) -> int:
    for i in range(num):
        yield i

gen = mygenerator(10)
for i in range(3):
    for j in gen:
        print(i, j)

そこでクラスの__iter__メソッドを利用する。参考:Pythonのイテレータとジェネレータ

class MyGen:
    def __init__(self.num):
        self.num = num

    def __iter__(self):
        for i in range(self.num):
            yield i

gen = MyGen(10)
for i in range(3):
    for j in gen:
        print(i, j)

実際に画像を都度読み込むジェネレータ

import cv2
import numpy as np

class MyGen:
    def __init__(self, paths:list):
        self.paths = paths

    def __iter__(self) -> np.ndarray:
        for path in self.paths:
            img = cv2.imread(path)
            # 前処理とかココに記述してもいいかも
            yield img

img_paths = ["./0001.jpg",
             "./0002.jpg",
             "./0003.jpg",]

gen = MyGen(img_paths)

epoch = 3
for ei in range(epoch):
    for img in gen:
        print(img.shape) 

tqdmで使う

繰り返し動作を視覚的に表示してくれるtqdmライブラリという物がある。
コレを利用することで1エポックの時間等が把握できる

# おためし
import tqdm
import time

# 1重ループ
for i in tqdm.trange(10):
    time.sleep(0.1)

# 2重ループ
for i in tqdm.trange(3):
    for j in tqdm.trange(10)
        time.sleep(0.1)

# リスト
items = [i for i in range(10)]
for item in tqdm.tqdm(items):
    time.sleep(0.1)

これをジェネレータでも利用できるようにするには__len__メソッドを利用する。

import tqdm
import time
# 省略のために継承を利用
class MyGen2(MyGen):
    def __len__(self):
        return len(self.paths)

gen = MyGen2(img_paths)

for ei in tqdm.trange(epoch):
    for img in tqdm.tqdm(gen):
        print(img.shape)
        time.sleep(0.5)
0
0
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0