4
3

More than 3 years have passed since last update.

NNablaでtensorboardを使う

Posted at

tensorboard は学習時にロスカーブを書いたり、ヒストグラムや画像を描画したりするのがとても便利なツールです。私は最近ソニー製のニューラルネットワークフレーム NNabla (https://nnabla.org/) を使っていますが、可視化ツールがなかったので、NNabla でも tensorboard を使えるように、python のパッケージを作りました。

基本は "tensorboardX for pytorch" をベースに作りました。

使い方

基本的には demp.py を実行してもらうとどんな感じか分かるかと思います。スカラ、ヒストグラム、画像などの描画に対応しています。

# Install
pip install 'git+https://github.com/naibo-code/nnabla_tensorboard.git'

# Demo
python examples/demo.py

スカラ

scaler

ヒストグラム

histogram

文字出力

text

NNabla + tensorboard で MNIST の学習を可視化

NNabla はこちらのリポジトリ https://github.com/sony/nnabla-examples/ で幾つかの examples を提供しています。今回はその中から MNIST の学習コード を使って、リアルタイムに学習結果を tensorboard で可視化してみました。

変更すべきのはこちらの2つの関数だけです(NEW と書かれた部分だけ。)。あとファイルの先頭にfrom nnabla_tensorboard import SummaryWriter でパッケージをインポートします。

from nnabla_tensorboard import SummaryWriter


def train():
    """
    Main script.

    Steps:

    * Parse command line arguments.
    * Specify a context for computation.
    * Initialize DataIterator for MNIST.
    * Construct a computation graph for training and validation.
    * Initialize a solver and set parameter variables to it.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Computate error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop on the training graph.
      * Compute training error
      * Set parameter gradients zero
      * Execute backprop.
      * Solver updates parameters by using gradients computed by backprop.
    """
    args = get_args()

    from numpy.random import seed
    seed(0)

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    if args.net == 'lenet':
        mnist_cnn_prediction = mnist_lenet_prediction
    elif args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction
    else:
        raise ValueError("Unknown network type {}".format(args.net))

    # TRAIN
    # Create input variables.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    pred = mnist_cnn_prediction(image, test=False, aug=args.augment_train)
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    vpred = mnist_cnn_prediction(vimage, test=True, aug=args.augment_test)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)

    # For tensorboard (NEW)
    tb_writer = SummaryWriter(args.monitor_path)

    # Initialize DataIterator for MNIST.
    from numpy.random import RandomState
    data = data_iterator_mnist(args.batch_size, True, rng=RandomState(1223))
    vdata = data_iterator_mnist(args.batch_size, False)
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation (NEW)
            validation(args, ctx, vdata, vimage, vlabel, vpred, i, tb_writer)

        if i % args.model_save_interval == 0:
            nn.save_parameters(os.path.join(
                args.model_save_path, 'params_%06d.h5' % i))
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        loss.data.cast(np.float32, ctx)
        pred.data.cast(np.float32, ctx)
        e = categorical_error(pred.d, label.d)

        # Instead of using nnabla.monitor, use nnabla_tensorboard. (NEW)
        if i % args.val_interval == 0:
            tb_writer.add_image('image/train_data_{}'.format(i), image.d[0])

        tb_writer.add_scalar('train/loss', loss.d.copy(), global_step=i)
        tb_writer.add_scalar('train/error', e, global_step=i)
        monitor_time.add(i)

    validation(args, ctx, vdata, vimage, vlabel, vpred, i, tb_writer)

    parameter_file = os.path.join(
        args.model_save_path, '{}_params_{:06}.h5'.format(args.net, args.max_iter))
    nn.save_parameters(parameter_file)

    # append F.Softmax to the prediction graph so users see intuitive outputs
    runtime_contents = {
        'networks': [
            {'name': 'Validation',
             'batch_size': args.batch_size,
             'outputs': {'y': F.softmax(vpred)},
             'names': {'x': vimage}}],
        'executors': [
            {'name': 'Runtime',
             'network': 'Validation',
             'data': ['x'],
             'output': ['y']}]}
    save.save(os.path.join(args.model_save_path,
                           '{}_result.nnp'.format(args.net)), runtime_contents)

    tb_writer.close()
def validation(args, ctx, vdata, vimage, vlabel, vpred, i, tb_writer):
    ve = 0.0
    for j in range(args.val_iter):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        vpred.data.cast(np.float32, ctx)
        ve += categorical_error(vpred.d, vlabel.d)
    tb_writer.add_scalar('test/error', ve / args.val_iter, i)

NNabla + tensorboard : MNIST の実行結果

学習カーブ
mnist_curve.png

入力イメージもplotしてみた。
mnist_image.png

自作スクリプトで描画したりする必要がなく、やっぱり tensorboard は便利ですね。

追加したい機能

  • Network graph を tensorboard に表示する機能。
  • NNabla をうまく使えば、中間層のデータの可視化も tensorboard でできちゃうかもしれません。(まだ色々調べ中・・・)
4
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
3