4
3

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Use Chainer from R with the reticulate package

Posted at
# -------------------------------------------------------------------------------
# Use Chainer (http://chainer.org/) from R
#   with reticulate package (https://github.com/rstudio/reticulate)
#
#   original code in Python:
#     https://github.com/pfnet/chainer/blob/master/examples/mnist/train_mnist.py
# -------------------------------------------------------------------------------

# setup R interface to Python --------------------------------------------------
library(reticulate)
# Python builtins module, used below for super() and delattr().
# import_builtins() resolves to `builtins` on Python 3 and `__builtin__` on
# Python 2; a hard-coded import("__builtin__") fails on Python 3.
py <- import_builtins()
# Python's __main__ module: py_run_string() (with its default local = FALSE)
# defines the pyClass helper here, and it is fetched/removed from here later.
main <- import("__main__")


# define python class generator ------------------------------------------------
# NOTE: generating the class by calling py$type() directly does not work (the
#       cause is not understood), so a small Python helper is defined instead.
# Define the Python helper pyClass() in Python's __main__ module; it simply
# forwards to type(name, bases, dict), substituting () / {} for None.
# The return value is deliberately discarded rather than assigned to `main`:
# recent reticulate versions return the globals *dictionary* here, not the
# module, and py$delattr(main, "pyClass") further down needs the module.
invisible(py_run_string("
def pyClass(classname, inherit=None, member=None):
    if inherit is None:
        inherit = ()
    if member is None:
        member = dict()
    return type(classname, inherit, member)
"))
# Create a Python class and wrap it as an R constructor function.
#
# classname: name of the Python class to create.
# inherit:   base classes, e.g. tuple(chainer$Chain); NULL means no bases.
# member:    class attributes/methods, e.g. dict(`__init__` = function(self) ...);
#            NULL means none.
#
# Returns an R function that instantiates the class. Arguments given to it
# with a name are passed as Python keyword arguments; unnamed ones are passed
# positionally. NULL values are kept (not dropped) in both cases. The Python
# class object itself is attached as the "py_object" attribute.
#
# `.pyClass` is not a parameter: it is looked up in this function's
# environment, where the script installs the Python-side helper.
pyClass <- function(classname, inherit = NULL, member = NULL) {
  py_class <- .pyClass(classname, inherit, member)
  constructor <- function(...) {
    dots <- list(...)
    arg_names <- names(dots)
    if (is.null(arg_names)) {
      # No argument was named: everything is positional.
      args <- dots
      keywords <- list()
    } else {
      # Split by name. Subsetting with `[` keeps NULL elements, so NULL
      # arguments are forwarded instead of silently disappearing.
      is_keyword <- nzchar(arg_names)
      args <- unname(dots[!is_keyword])
      keywords <- dots[is_keyword]
    }
    result <- py_call(py_class, args, keywords)
    if (is.null(result)) invisible(result) else result
  }
  structure(constructor, py_object = py_class)
}
# Hand the Python-side helper to pyClass(), which looks up `.pyClass` in its
# own environment (the global environment when this script runs at top
# level), then remove the helper from Python's __main__ so it does not linger.
assign(".pyClass", main$pyClass, envir = environment(pyClass))
py$delattr(main, "pyClass")



# chainer modules --------------------------------------------------------------
# Top-level chainer package and the training framework below it.
chainer <- import("chainer")
training <- chainer$training
extensions <- training$extensions

# Short aliases mirroring the Python convention
# (`import chainer.functions as F`, `import chainer.links as L`).
# Note: assigning F masks R's F shorthand for FALSE in this script.
F <- import("chainer.functions")  # used below as F$relu
L <- import("chainer.links")      # used below as L$Linear, L$Classifier


# Network definition -----------------------------------------------------------
# Multi-layer perceptron: two ReLU-activated Linear layers of n_units followed
# by a Linear output layer of n_out. Input sizes are given as NULL (None) so
# they are inferred rather than fixed here.
MLP <- pyClass(
  "MLP",
  inherit = tuple(chainer$Chain),
  member = dict(
    `__init__` = function(self, n_units, n_out) {
      # Pass the three links to chainer.Chain's constructor under the names
      # l1, l2, l3; __call__ reaches them as self$l1, self$l2, self$l3.
      parent <- py$super(MLP, self)
      parent$`__init__`(
        l1 = L$Linear(NULL, n_units),  # n_in    -> n_units
        l2 = L$Linear(NULL, n_units),  # n_units -> n_units
        l3 = L$Linear(NULL, n_out)     # n_units -> n_out
      )
    },
    `__call__` = function(self, x) {
      hidden <- F$relu(self$l1(x))
      hidden <- F$relu(self$l2(hidden))
      self$l3(hidden)
    }
  )
)


# main -------------------------------------------------------------------------
args <- list(
  batchsize = 100L,    # Number of images in each mini-batch
  epoch     = 10L,     # Number of sweeps over the dataset to train
  unit      = 1000L,   # Number of units
  out       = "result" # Directory to output the result
)

# Set up a neural network to train
# Classifier reports softmax cross entropy loss and accuracy at every
# iteration, which will be used by the PrintReport extension below.
model <- L$Classifier(MLP(args$unit, 10L))

# Setup an optimizer
optimizer <- chainer$optimizers$Adam()
optimizer$setup(model)

# Load the MNIST dataset (get_mnist() returns a (train, test) pair)
mnist_data <- chainer$datasets$get_mnist()
train <- mnist_data[[1]]
test  <- mnist_data[[2]]

train_iter <- chainer$iterators$SerialIterator(train, args$batchsize)
# The test iterator makes a single, unshuffled pass per evaluation.
# `repeat` is a reserved word in R, so it must be backquoted to be passed as
# Python's repeat= keyword argument.
test_iter <- chainer$iterators$SerialIterator(
  test, args$batchsize,
  `repeat` = FALSE, shuffle = FALSE
)

# Set up a trainer
updater <- training$StandardUpdater(train_iter, optimizer)
trainer <- training$Trainer(updater, tuple(args$epoch, "epoch"), out = args$out)

# Evaluate the model with the test dataset for each epoch
trainer$extend(extensions$Evaluator(test_iter, model))

# Dump a computational graph from 'loss' variable at the first iteration
# The "main" refers to the target link of the "main" optimizer.
trainer$extend(extensions$dump_graph("main/loss"))

# Take a snapshot every args$epoch epochs. Because that interval equals the
# whole training length, this writes a single snapshot at the end of training.
trainer$extend(extensions$snapshot(), trigger = tuple(args$epoch, "epoch"))

# Write a log of evaluation statistics for each epoch
trainer$extend(extensions$LogReport())

# Save two plot images to the result dir
trainer$extend(
  extensions$PlotReport(c("main/loss", "validation/main/loss"), "epoch",
                        file_name = "loss.png"))
trainer$extend(
  extensions$PlotReport(c("main/accuracy", "validation/main/accuracy"),
                        "epoch", file_name = "accuracy.png"))

# Print selected entries of the log to stdout
# Here "main" refers to the target link of the "main" optimizer again, and
# "validation" refers to the default name of the Evaluator extension.
# Entries other than 'epoch' are reported by the Classifier link, called by
# either the updater or the evaluator.
trainer$extend(extensions$PrintReport(
  c("epoch", "main/loss", "validation/main/loss",
    "main/accuracy", "validation/main/accuracy", "elapsed_time")))

# Print a progress bar to stdout
trainer$extend(extensions$ProgressBar())

# Run the training
trainer$run()
4
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do by signing up
4
3

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?