# -------------------------------------------------------------------------------
# Use Chainer (http://chainer.org/) from R
#   with reticulate package (https://github.com/rstudio/reticulate)
#
#   original code in Python:
#     https://github.com/pfnet/chainer/blob/master/examples/mnist/train_mnist.py
# -------------------------------------------------------------------------------
# setup R interface to Python --------------------------------------------------
library(reticulate)
# Handles to Python's builtins module and its __main__ module.
# import_builtins() resolves to "__builtin__" on Python 2 and "builtins" on
# Python 3; a hard-coded import("__builtin__") fails on Python 3.
py <- import_builtins()
main <- import_main()
# define python class generator ------------------------------------------------
# NOTE: generating the class by calling py$type() directly from R did not work
# for the original author (cause not investigated), so a tiny Python helper
# wraps type() instead. The helper is defined in Python's __main__ namespace
# (py_run_string()'s default, local = FALSE).
# The return value is deliberately NOT assigned to `main`: newer reticulate
# releases document py_run_string() as returning the execution dictionary
# rather than the __main__ module, and `main` should keep referring to the
# module itself.
invisible(py_run_string("
def pyClass(classname, inherit=None, member=None):
    if inherit is None:
        inherit = ()
    if member is None:
        member = dict()
    return type(classname, inherit, member)
"))
# R-side constructor factory for Python classes.
#
# Creates the Python type through the `.pyClass` helper (bound further down
# in this script) and returns an R function. When that function is called,
# it splits its `...` into positional and keyword arguments and instantiates
# the class via `py_call()`. The Python type object is attached as the
# "py_object" attribute. Presumably reticulate uses this attribute when the
# wrapper is handed back to Python (e.g. `py$super(MLP, self)`).
#
# classname: name of the new Python class (character scalar).
# inherit:   tuple of base classes, or NULL for none.
# member:    dict of class attributes/methods, or NULL for none.
# Returns the constructor function. Calling it returns the result of
# py_call(), invisibly when that result is NULL.
pyClass <- function(classname, inherit=NULL, member=NULL) {
  py_type <- .pyClass(classname, inherit, member)
  constructor <- function(...) {
    dots <- list(...)
    arg_names <- names(dots)
    positional <- list()
    named <- list()
    if (is.null(arg_names)) {
      # No names at all: everything is positional.
      positional <- dots
    } else {
      for (k in seq_along(dots)) {
        key <- arg_names[[k]]
        value <- dots[[k]]
        # `x[i] <- list(NULL)` stores a NULL element, whereas
        # `x[[i]] <- NULL` would delete it, hence the split on is.null().
        if (nzchar(key)) {
          if (is.null(value)) {
            named[key] <- list(NULL)
          } else {
            named[[key]] <- value
          }
        } else if (is.null(value)) {
          positional[length(positional) + 1] <- list(NULL)
        } else {
          positional[[length(positional) + 1]] <- value
        }
      }
    }
    out <- py_call(py_type, positional, named)
    if (is.null(out)) invisible(out) else out
  }
  structure(constructor, py_object = py_type)
}
# Hand the Python-side class factory to the R wrapper, then remove it from
# Python's __main__ namespace so the helper does not linger there.
# pyClass was defined at top level, so environment(pyClass) is the global
# environment and `.pyClass` is stored there.
# import_main() and a Python `del` are used instead of the `main` variable
# because newer reticulate releases make py_run_string() return a dict
# rather than the __main__ module, and delattr() on a dict would fail.
environment(pyClass)$.pyClass <- import_main()$pyClass
invisible(py_run_string("del pyClass"))
# chainer modules --------------------------------------------------------------
chainer    <- import("chainer")
# Short aliases for chainer.functions and chainer.links.
# NOTE: assigning `F` masks R's built-in `F` shorthand for FALSE in this
# environment, so always write `FALSE` (never `F`) for logicals in this script.
F          <- import("chainer.functions")
L          <- import("chainer.links")
training   <- chainer$training
extensions <- training$extensions
# Network definition -----------------------------------------------------------
# Three-layer perceptron defined as a Python subclass of chainer.Chain.
# l1 and l2 produce `n_units` outputs, and l3 produces `n_out`. Each Linear
# link receives NULL (None) as its input size so that the size is inferred.
MLP <- pyClass("MLP", inherit = tuple(chainer$Chain),
  member = dict(
    # Constructor: passes the three Linear links to chainer.Chain's
    # constructor as keyword arguments named l1, l2, l3.
    `__init__` = function(self, n_units, n_out) {
      parent <- py$super(MLP, self)
      parent$`__init__`(
        l1 = L$Linear(NULL, n_units),  # n_in    -> n_units
        l2 = L$Linear(NULL, n_units),  # n_units -> n_units
        l3 = L$Linear(NULL, n_out)     # n_units -> n_out
      )
    },
    # Forward pass: two ReLU-activated hidden layers, then a linear output.
    `__call__` = function(self, x) {
      hidden <- F$relu(self$l1(x))
      hidden <- F$relu(self$l2(hidden))
      self$l3(hidden)
    }
  )
)
# main -------------------------------------------------------------------------
# Training settings.
args <- list(
  batchsize = 100L,     # Number of images in each mini-batch
  epoch     = 10L,      # Number of sweeps over the dataset to train
  unit      = 1000L,    # Number of units
  out       = "result"  # Directory to output the result
)

# Set up a neural network to train.
# Classifier reports softmax cross entropy loss and accuracy at every
# iteration, which will be used by the PrintReport extension below.
model <- L$Classifier(MLP(args$unit, 10L))

# Set up an optimizer.
optimizer <- chainer$optimizers$Adam()
optimizer$setup(model)

# Load the MNIST dataset: element 1 is the training split, element 2 the
# test split.
mnist_data <- chainer$datasets$get_mnist()
train <- mnist_data[[1]]
test  <- mnist_data[[2]]

train_iter <- chainer$iterators$SerialIterator(train, args$batchsize)
# The test iterator makes a single pass without shuffling. The flags are
# passed by name rather than as two anonymous positional FALSE values.
# `repeat` needs backticks because it is a reserved word in R.
test_iter <- chainer$iterators$SerialIterator(
  test, args$batchsize, `repeat` = FALSE, shuffle = FALSE
)

# Set up a trainer.
updater <- training$StandardUpdater(train_iter, optimizer)
trainer <- training$Trainer(
  updater, tuple(args$epoch, "epoch"), out = args$out
)

# Evaluate the model with the test dataset for each epoch.
trainer$extend(extensions$Evaluator(test_iter, model))

# Dump a computational graph from the 'loss' variable at the first iteration.
# The "main" refers to the target link of the "main" optimizer.
trainer$extend(extensions$dump_graph("main/loss"))

# Take a snapshot every `args$epoch` epochs. Training also stops after
# `args$epoch` epochs, so this fires once, at the end of training.
trainer$extend(extensions$snapshot(), trigger = tuple(args$epoch, "epoch"))

# Write a log of evaluation statistics for each epoch.
trainer$extend(extensions$LogReport())

# Save two plot images to the result dir. PlotReport.available() reports
# whether its plotting backend (matplotlib, per Chainer's docs) can be used,
# so only register the plots when it is.
if (isTRUE(extensions$PlotReport$available())) {
  trainer$extend(
    extensions$PlotReport(c("main/loss", "validation/main/loss"), "epoch",
                          file_name = "loss.png"))
  trainer$extend(
    extensions$PlotReport(c("main/accuracy", "validation/main/accuracy"),
                          "epoch", file_name = "accuracy.png"))
}

# Print selected entries of the log to stdout.
# Here "main" refers to the target link of the "main" optimizer again, and
# "validation" refers to the default name of the Evaluator extension.
# Entries other than 'epoch' are reported by the Classifier link, called by
# either the updater or the evaluator.
trainer$extend(extensions$PrintReport(
  c("epoch", "main/loss", "validation/main/loss",
    "main/accuracy", "validation/main/accuracy", "elapsed_time")))

# Print a progress bar to stdout.
trainer$extend(extensions$ProgressBar())

# Run the training.
trainer$run()
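# (The lines below are residue from the Qiita page this script was copied
# from; they are kept as comments so the file still parses as R.)
# More than 5 years have passed since last update.
# Register as a new user and use Qiita more conveniently
# - You get articles that match your needs
# - You can efficiently read back useful information
# - You can use dark theme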