# -------------------------------------------------------------------------------
# Use Chainer (http://chainer.org/) from R
# with reticulate package (https://github.com/rstudio/reticulate)
#
# original code in Python:
# https://github.com/pfnet/chainer/blob/master/examples/mnist/train_mnist.py
# -------------------------------------------------------------------------------
# setup R interface to Python --------------------------------------------------
library(reticulate)
# Handle to Python's builtin namespace (used below for super() and delattr()).
# import_builtins() resolves to `__builtin__` on Python 2 and `builtins` on
# Python 3, whereas import("__builtin__") fails on Python 3.
py <- import_builtins()
# Handle to the Python __main__ module.
main <- import("__main__")
# define python class generator ------------------------------------------------
# NOTE: generating the class by calling py$type() directly from R did not work
# for the original author (cause not investigated), so a small Python helper
# is defined in __main__ instead and called from R.
# The Python body must keep its indentation: Python is whitespace-sensitive.
invisible(py_run_string("
def pyClass(classname, inherit=None, member=None):
    if inherit is None:
        inherit = ()
    if member is None:
        member = dict()
    return type(classname, inherit, member)
"))
# Re-fetch __main__ as a module object. Depending on the reticulate version,
# py_run_string() may return the module's globals dict rather than the module,
# and the code that follows reads `main$pyClass` and calls delattr() on it.
main <- import_main()
# Build an R constructor for a Python class.
#
# classname: name of the new Python class.
# inherit:   tuple of base classes (NULL means no explicit bases).
# member:    dict of class attributes / methods (NULL means none).
#
# Returns an R function that instantiates the class. Unnamed arguments are
# forwarded positionally and named arguments as Python keyword arguments;
# NULL values are kept (they become Python None). The Python type object is
# attached as the "py_object" attribute. If the call yields NULL, the result
# is returned invisibly.
pyClass <- function(classname, inherit = NULL, member = NULL) {
  py_type <- .pyClass(classname, inherit, member)

  constructor <- function(...) {
    dots <- list(...)
    arg_names <- names(dots)
    positional <- list()
    named <- list()

    if (is.null(arg_names)) {
      # No argument was named: everything is positional.
      positional <- dots
    } else {
      for (k in seq_along(dots)) {
        key <- arg_names[[k]]
        # Single-bracket assignment of a length-one list stores the element
        # as-is, including NULL, without dropping the slot.
        if (nzchar(key)) {
          named[key] <- dots[k]
        } else {
          positional[length(positional) + 1] <- dots[k]
        }
      }
    }

    out <- py_call(py_type, positional, named)
    if (is.null(out)) invisible(out) else out
  }

  structure(constructor, py_object = py_type)
}
# Bind the Python helper as `.pyClass` in pyClass()'s enclosing environment,
# where pyClass() looks it up. Then delete it from Python's __main__ so that
# only the R wrapper remains.
# import_main() always yields the __main__ module whatever `main` currently
# holds, and `del` runs in __main__'s globals. delattr() would fail if `main`
# were the globals dict or if `py` were bound to `__builtin__` on Python 3.
environment(pyClass)$.pyClass <- import_main()$pyClass
invisible(py_run_string("del pyClass"))
# chainer modules --------------------------------------------------------------
# The top-level package plus the two submodules used by the network below.
# NOTE(review): `F` masks R's FALSE shorthand for the rest of this script, so
# always spell out FALSE.
chainer <- import("chainer")
L <- import("chainer.links")
F <- import("chainer.functions")
# Trainer machinery, reached through attributes of the package object.
training <- chainer$training
extensions <- training$extensions
# Network definition -----------------------------------------------------------
# A three-layer perceptron defined as a Python subclass of chainer.Chain.
# Each Linear layer receives NULL (Python None) as its input size so that the
# size is inferred rather than fixed here.
MLP <- pyClass(
  "MLP",
  inherit = tuple(chainer$Chain),
  member = dict(
    # Register the three links by passing them to the parent constructor.
    `__init__` = function(self, n_units, n_out) {
      parent <- py$super(MLP, self)
      parent$`__init__`(
        l1 = L$Linear(NULL, n_units),  # n_in    -> n_units
        l2 = L$Linear(NULL, n_units),  # n_units -> n_units
        l3 = L$Linear(NULL, n_out)     # n_units -> n_out
      )
    },
    # Forward pass: two ReLU hidden layers, then a linear output layer.
    `__call__` = function(self, x) {
      hidden <- F$relu(self$l1(x))
      hidden <- F$relu(self$l2(hidden))
      self$l3(hidden)
    }
  )
)
# main -------------------------------------------------------------------------
# Training settings.
args <- list(
  batchsize = 100L,    # number of images in each mini-batch
  epoch     = 10L,     # number of sweeps over the dataset to train
  unit      = 1000L,   # number of hidden units per layer
  out       = "result" # directory to output the result
)

# Build the network to train. L$Classifier wraps it and reports softmax cross
# entropy loss and accuracy at every iteration; the PrintReport extension set
# up further below prints those values.
model <- L$Classifier(MLP(args$unit, 10L))  # 10 output classes (digits 0-9)

# Create an Adam optimizer and bind it to the model.
optimizer <- chainer$optimizers$Adam()
optimizer$setup(model)
# Load the MNIST dataset; get_mnist() returns a (train, test) pair.
mnist_data <- chainer$datasets$get_mnist()
train <- mnist_data[[1]]
test <- mnist_data[[2]]
# The training iterator uses SerialIterator's default repeat/shuffle settings.
train_iter <- chainer$iterators$SerialIterator(train, args$batchsize)
# The test iterator makes a single, unshuffled pass each time it is consumed.
# `repeat` is a reserved word in R, so it is backquoted to be passed as the
# Python keyword argument of the same name.
test_iter <- chainer$iterators$SerialIterator(
  test, args$batchsize,
  `repeat` = FALSE, shuffle = FALSE
)
# Set up a trainer that runs for args$epoch epochs and writes its outputs
# (logs, plots, snapshots) under args$out.
updater <- training$StandardUpdater(train_iter, optimizer)
# Interval trigger that fires every args$epoch epochs. Since training stops
# there, it fires exactly once, at the end. It serves both as the stop
# condition and as the snapshot schedule below.
end_of_training <- tuple(args$epoch, "epoch")
trainer <- training$Trainer(updater, end_of_training, out = args$out)
# Evaluate the model with the test dataset for each epoch
trainer$extend(extensions$Evaluator(test_iter, model))
# Dump a computational graph from 'loss' variable at the first iteration
# The "main" refers to the target link of the "main" optimizer.
trainer$extend(extensions$dump_graph("main/loss"))
# Take a single snapshot at the end of training (not after every epoch).
# Use trigger = tuple(1L, "epoch") instead to snapshot after each epoch.
trainer$extend(extensions$snapshot(), trigger = end_of_training)
# Write a log of evaluation statistics for each epoch
trainer$extend(extensions$LogReport())
# Save two plot images to the result dir
trainer$extend(
  extensions$PlotReport(c("main/loss", "validation/main/loss"), "epoch",
                        file_name = "loss.png"))
trainer$extend(
  extensions$PlotReport(c("main/accuracy", "validation/main/accuracy"),
                        "epoch", file_name = "accuracy.png"))
# Print selected entries of the log to stdout
# Here "main" refers to the target link of the "main" optimizer again, and
# "validation" refers to the default name of the Evaluator extension.
# Entries other than 'epoch' are reported by the Classifier link, called by
# either the updater or the evaluator.
trainer$extend(extensions$PrintReport(
  c("epoch", "main/loss", "validation/main/loss",
    "main/accuracy", "validation/main/accuracy", "elapsed_time")))
# Print a progress bar to stdout
trainer$extend(extensions$ProgressBar())
# Run the training
trainer$run()
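# The lines below were captured from the Qiita page hosting this script and
# are not part of the program:
# More than 5 years have passed since last update.
# Register as a new user and use Qiita more conveniently
# - You get articles that match your needs
# - You can efficiently read back useful information
# - You can use dark theme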