14
17

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

ソフトウェアテストAdvent Calendar 2017

Day 10

Jupyter Notebookでpytest

Last updated at Posted at 2017-12-09

Jupyter Notebookでpytest

Jupyter Notebookで作成した.ipynbファイルにpytestで単体テストするツールとしてpytest_ipynbがあります。
https://pypi.python.org/pypi/pytest-ipynb

こちらでご紹介されているとおり、py.test -vで呼び出してセル毎にテストすることができる優れものです。
インストール方法は簡単で、以下を実行するだけで使い始めることができます。

pip install pytest-ipynb

日々機械学習やディープラーニングをJupyter Notebookで書いているので、試しに使ってみました。

とりあえずMNIST

試しにMNISTのニューラルネットワークを以下のとおり作って実行してみました。
なお、以下プログラム中の#####################はセルの区切りです。
実際のNotebookには書いていません。


#####################
"""fixture"""
import os
import time
import string
import pytest

#####################
"""import"""
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K

#####################
"""data preparation"""
batch_size = 128
num_classes = 10
epochs = 4

# input image dimensions
img_rows, img_cols = 28, 28

# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

assert x_train.shape == (60000,28,28,1)
assert x_test.shape == (10000,28,28,1)
assert y_train.shape == (60000, )
assert y_test.shape == (10000, )
assert input_shape == (img_rows, img_cols, 1)

#####################
"""prepare x"""
@pytest.mark.timeout(180)
def prepareX(x_train, x_test):
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255
    return x_train,x_test
x_train,x_test = prepareX(x_train,x_test)

#####################
"""print x shape"""
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

#####################
"""prepare y"""
# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

assert y_train.shape == (60000,10)
assert y_test.shape == (10000,10)

#####################
"""build model"""
@pytest.mark.timeout(10)
def test_buildModel():
    model = Sequential()
    model.add(Conv2D(32, kernel_size=(3, 3),
                     activation='relu',
                     input_shape=input_shape))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(num_classes, activation='softmax'))
    return model
model = test_buildModel()

#####################
"""compile model"""
@pytest.mark.timeout(10)
def test_compileModel(model):
    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.Adadelta(),
                  metrics=['accuracy'])
    return model
model = test_compileModel(model)

#####################
"""train model"""
@pytest.mark.timeout(600)
def test_trainModel(model):
    history = model.fit(x_train, y_train,
                        batch_size=batch_size,
                        epochs=epochs,
                        verbose=0,
                        validation_data=(x_test, y_test))
    return model,history
model,history = test_trainModel(model)

#####################
"""score model"""
@pytest.mark.timeout(120)
def test_scoreModel(model):
    score = model.evaluate(x_test, y_test, verbose=0)
    print(score)
    return score
score = test_scoreModel(model)
assert score[0] < 0.05
assert score[1] > 0.98

#####################
"""print score"""
print('Test loss:', score[0])
print('Test accuracy:', score[1])

つまずいた点

Jupyter Notebook上に書いたassert xxxがセル毎に評価されます。
Jupyter Notebookで各セルを実行する際にもassertされます。
ただし、ここで使っているのはpytestであって、pytest-ipynbではありません。
pytestとしてassertされています。

pytest-ipynbは.ipynbファイルをpy.test -vでテストするためのツールになります。
コンソール上でpy.test -vを実行すれば、test_**.ipynbファイルの各セルに対してテストを実行してくれるものです。

そしてpytest-ipynbのタイムアウトは、以下のようにセルに記載したタイムアウト値を参照しません。
どこを参照するかというと、pytest-ipynbがインストールされたpytest_ipynb/plugin.pyの中に20秒とハードコーディングされています。
py.test -v --timeout=600のように--timeoutオプションを付けても無駄でした。
そのため、20秒を超えるセルをテストすると、20秒経過してFailureになります。

1.PNG

というわけで、pytest-ipynbでコンソール上から.ipynbファイルをテストする場合、実行に20秒を超えるセルのためには、pytest_ipynb/plugin.pyの以下部分を変更しましょう。


def runtest(self):
    #self.parent.runner.km.restart_kernel()

    if self.parent.notebook_folder:
        self.parent.runner.kc.execute(
"""import os
os.chdir("%s")""" % self.parent.notebook_folder)

    if ("SKIPCI" in self.cell_description) and ("CI" in os.environ):
        pass
    else:
        if self.parent.fixture_cell:
            self.parent.runner.kc.execute(self.parent.fixture_cell.input, allow_stdin=False)
        msg_id = self.parent.runner.kc.execute(self.cell.input, allow_stdin=False)
        if self.cell_description.lower().startswith("fixture") or self.cell_description.lower().startswith("setup"):
            self.parent.fixture_cell = self.cell
        timeout = 20 #←← ここ!!! 
        while True:
            try:
                msg = self.parent.runner.kc.get_shell_msg(block=True, timeout=timeout)
                if msg.get("parent_header", None) and msg["parent_header"].get("msg_id", None) == msg_id:
                    break
            except Empty:
                raise IPyNbException("Timeout of %d seconds exceeded executing cell: %s" (timeout, self.cell.input))

        reply = msg['content']

        if reply['status'] == 'error':
            raise IPyNbException(self.cell_num, self.cell_description, self.cell.input, '\n'.join(reply['traceback']))

viかなにかで編集しましょう。
これがわからなくて同じテストを10回くらい流しました。

0.png

14
17
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
17

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?