#Overview
I tried TensorFlow.js on jsdo.it. I also came across a DQN and gave that a try, though the sample below is actually a policy-gradient (REINFORCE-style) agent rather than a DQN.
#Screenshot
#Sample code
```javascript
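// Cart-pole environment rendered on a <canvas>. This appears to be a
// JavaScript port of the classic CartPole control task (as in OpenAI Gym):
// the state is [x, x_dot, theta, theta_dot], and the two discrete actions
// push the cart left or right with a fixed force.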
var Cartpole = function() {
  this.canvas = document.getElementById('canvas');
  this.ctx = this.canvas.getContext('2d');
  this.gravity = 9.8;
  this.masscart = 1.0;
  this.masspole = 0.1;
  this.total_mass = (this.masspole + this.masscart);
  this.length = 0.5;
  this.polemass_length = (this.masspole * this.length);
  this.force_mag = 10.0;
  this.tau = 0.02;
  this.theta_threshold_radians = 12 * 2 * Math.PI / 360;
  this.x_threshold = 2.4;
  this.state = {
    x: 0,
    x_dot: 0,
    theta: 0,
    theta_dot: 0
  };
  this.steps_beyond_done = null;
};
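// Reset to a small random initial state; each component is drawn uniformly
// from [-0.2, 0.2].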
Cartpole.prototype.reset = function() {
  this.state = {
    x: Math.random() * 0.4 - 0.2,
    x_dot: Math.random() * 0.4 - 0.2,
    theta: Math.random() * 0.4 - 0.2,
    theta_dot: Math.random() * 0.4 - 0.2
  };
  this.steps_beyond_done = null;
  return this.state;
};
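// Advance the physics by one tau-sized Euler step and return
// [state, reward, done]. The episode ends when the cart leaves
// [-x_threshold, x_threshold] or the pole tilts beyond about 12 degrees.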
Cartpole.prototype.step = function(action) {
  var state = this.state;
  var x = state.x;
  var x_dot = state.x_dot;
  var theta = state.theta;
  var theta_dot = state.theta_dot;
  var force = this.force_mag;
  if (action < 1) force = -this.force_mag;
  var costheta = Math.cos(theta);
  var sintheta = Math.sin(theta);
  var temp = (force + this.polemass_length * theta_dot * theta_dot * sintheta) / this.total_mass;
  var thetaacc = (this.gravity * sintheta - costheta * temp) / (this.length * (4.0 / 3.0 - this.masspole * costheta * costheta / this.total_mass));
  var xacc = temp - this.polemass_length * thetaacc * costheta / this.total_mass;
  x = x + this.tau * x_dot;
  x_dot = x_dot + this.tau * xacc;
  theta = theta + this.tau * theta_dot;
  theta_dot = theta_dot + this.tau * thetaacc;
  this.state = {
    x: x,
    x_dot: x_dot,
    theta: theta,
    theta_dot: theta_dot
  };
  var reward;
  var done = x < -this.x_threshold || x > this.x_threshold ||
      theta < -this.theta_threshold_radians || theta > this.theta_threshold_radians;
  if (!done) {
    reward = 1.0;
  } else if (this.steps_beyond_done == null) {
    // The episode has just ended; one final reward is still granted.
    this.steps_beyond_done = 0;
    reward = 1.0;
  } else {
    if (this.steps_beyond_done == 0) {
      console.warn('step() called after the episode is done; call reset() first.');
    }
    this.steps_beyond_done += 1;
    reward = 0.0;
  }
  return [this.state, reward, done];
};
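// Draw the current state: the cart as a thick orange bar and the pole as a
// black line tilted by theta from vertical.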
Cartpole.prototype.render = function() {
  var scale = 100;
  var x = this.state.x;
  var theta = this.state.theta;
  var cartx = x * scale + this.canvas.width / 2.0;
  this.ctx.clearRect(0, 0, this.canvas.width, this.canvas.height);
  this.ctx.beginPath();
  this.ctx.strokeStyle = '#ffa500';
  this.ctx.lineWidth = 8;
  this.ctx.moveTo(cartx - 20, 200);
  this.ctx.lineTo(cartx + 20, 200);
  this.ctx.stroke();
  this.ctx.beginPath();
  this.ctx.strokeStyle = 'black';
  this.ctx.lineWidth = 2;
  this.ctx.moveTo(cartx, 200);
  this.ctx.lineTo(cartx + Math.sin(theta) * 100, 200 - Math.cos(theta) * 100);
  this.ctx.stroke();
};
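// Pack the current state into a 1x4 tensor, the input shape the policy
// network expects.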
Cartpole.prototype.getStateTensor = function() {
  return tf.tensor2d([[this.state.x, this.state.x_dot, this.state.theta, this.state.theta_dot]]);
};
var env = new Cartpole();
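// REINFORCE-style policy-gradient agent. The network maps the 4-D state to a
// single logit; sigmoid(logit) is treated as the probability of the "left"
// action.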
class PolicyNetwork {
  constructor(hiddenLayerSizesOrModel) {
    if (hiddenLayerSizesOrModel instanceof tf.Model) {
      this.model = hiddenLayerSizesOrModel;
    } else {
      this.createModel(hiddenLayerSizesOrModel);
    }
  }
  createModel(hiddenLayerSizes) {
    if (!Array.isArray(hiddenLayerSizes)) {
      hiddenLayerSizes = [hiddenLayerSizes];
    }
    this.model = tf.sequential();
    hiddenLayerSizes.forEach((hiddenLayerSize, i) => {
      this.model.add(tf.layers.dense({
        units: hiddenLayerSize,
        activation: 'elu',
        inputShape: i === 0 ? [4] : undefined
      }));
    });
    this.model.add(tf.layers.dense({units: 1}));
  }
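  // Play `numGames` episodes of at most `maxStepsPerGame` steps. For every
  // step, the gradients of the policy loss are recorded; after all games they
  // are weighted by the discounted, normalized rewards and applied once with
  // the optimizer.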
  async train(env, optimizer, discountRate, numGames, maxStepsPerGame) {
    const allGradients = [];
    const allRewards = [];
    const gameSteps = [];
    for (let i = 0; i < numGames; ++i) {
      const gameRewards = [];
      const gameGradients = [];
      var observation = env.reset();
      for (let j = 0; j < maxStepsPerGame; ++j) {
        env.render();
        const gradients = tf.tidy(() => {
          const inputTensor = env.getStateTensor();
          return this.getGradientsAndSaveActions(inputTensor).grads;
        });
        this.pushGradients(gameGradients, gradients);
        const action = this.currentActions_[0];
        var res = env.step(action);
        observation = res[0];
        var reward = res[1];
        var isDone = res[2];
        if (isDone) {
          gameRewards.push(0);
          break;
        } else {
          gameRewards.push(1);
        }
      }
      gameSteps.push(gameRewards.length);
      this.pushGradients(allGradients, gameGradients);
      allRewards.push(gameRewards);
      // Yield to the browser so the canvas actually repaints between games.
      await tf.nextFrame();
    }
    tf.tidy(() => {
      const normalizedRewards = discountAndNormalizeRewards(allRewards, discountRate);
      optimizer.applyGradients(scaleAndAverageGradients(allGradients, normalizedRewards));
    });
    tf.dispose(allGradients);
    return gameSteps;
  }
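  // Sample an action from the current policy, remember it in
  // this.currentActions_, and return the per-variable gradients of the sigmoid
  // cross-entropy loss between the sampled action (used as a pseudo-label) and
  // the network's logit.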
  getGradientsAndSaveActions(inputTensor) {
    const f = () => tf.tidy(() => {
      const [logits, actions] = this.getLogitsAndActions(inputTensor);
      this.currentActions_ = actions.dataSync();
      const labels = tf.sub(1, tf.tensor2d(this.currentActions_, actions.shape));
      return tf.losses.sigmoidCrossEntropy(labels, logits).asScalar();
    });
    return tf.variableGrads(f);
  }
  getCurrentActions() {
    return this.currentActions_;
  }
  getLogitsAndActions(inputs) {
    return tf.tidy(() => {
      const logits = this.model.predict(inputs);
      const leftProb = tf.sigmoid(logits);
      const leftRightProbs = tf.concat([leftProb, tf.sub(1, leftProb)], 1);
      const actions = tf.multinomial(leftRightProbs, 1, null, true);
      return [logits, actions];
    });
  }
  getActions(inputs) {
    return this.getLogitsAndActions(inputs)[1].dataSync();
  }
  pushGradients(record, gradients) {
    for (const key in gradients) {
      if (key in record) {
        record[key].push(gradients[key]);
      } else {
        record[key] = [gradients[key]];
      }
    }
  }
}
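// Key under which the trained model is persisted in the browser's IndexedDB.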
const MODEL_SAVE_PATH_ = 'indexeddb://cart-pole-v1';
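// PolicyNetwork variant that can be saved to and restored from IndexedDB.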
class SaveablePolicyNetwork extends PolicyNetwork {
  constructor(hiddenLayerSizesOrModel) {
    super(hiddenLayerSizesOrModel);
  }
  saveModel() {
    return this.model.save(MODEL_SAVE_PATH_);
  }
  static async loadModel() {
    const modelsInfo = await tf.io.listModels();
    if (MODEL_SAVE_PATH_ in modelsInfo) {
      const model = await tf.loadModel(MODEL_SAVE_PATH_);
      return new SaveablePolicyNetwork(model);
    } else {
      throw new Error(`Cannot find model at ${MODEL_SAVE_PATH_}.`);
    }
  }
  async checkStoredModelStatus() {
    const modelsInfo = await tf.io.listModels();
    return modelsInfo[MODEL_SAVE_PATH_];
  }
  removeModel() {
    return tf.io.removeModel(MODEL_SAVE_PATH_);
  }
  hiddenLayerSizes() {
    const sizes = [];
    for (let i = 0; i < this.model.layers.length - 1; ++i) {
      sizes.push(this.model.layers[i].units);
    }
    return sizes.length === 1 ? sizes[0] : sizes;
  }
}
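// Discounted return for one episode, computed right to left:
// discounted[i] = rewards[i] + discountRate * discounted[i + 1].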
function discountRewards(rewards, discountRate) {
  const discountedBuffer = tf.buffer([rewards.length]);
  let prev = 0;
  for (let i = rewards.length - 1; i >= 0; --i) {
    const current = discountRate * prev + rewards[i];
    discountedBuffer.set(current, i);
    prev = current;
  }
  return discountedBuffer.toTensor();
}
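// Discount each episode's rewards, then normalize across all episodes to
// zero mean and unit standard deviation, which keeps the gradient scale
// stable across games of different lengths.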
function discountAndNormalizeRewards(rewardSequences, discountRate) {
  return tf.tidy(() => {
    const discounted = [];
    for (const sequence of rewardSequences) {
      discounted.push(discountRewards(sequence, discountRate));
    }
    const concatenated = tf.concat(discounted);
    const mean = tf.mean(concatenated);
    const std = tf.sqrt(tf.mean(tf.square(concatenated.sub(mean))));
    const normalized = discounted.map(rs => rs.sub(mean).div(std));
    return normalized;
  });
}
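// Weight every step's gradients by that step's normalized reward, then
// average over all steps of all games to get one update per variable.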
function scaleAndAverageGradients(allGradients, normalizedRewards) {
  return tf.tidy(() => {
    const gradients = {};
    for (const varName in allGradients) {
      gradients[varName] = tf.tidy(() => {
        const varGradients = allGradients[varName].map(varGameGradients => tf.stack(varGameGradients));
        const expandedDims = [];
        for (let i = 0; i < varGradients[0].rank - 1; ++i) {
          expandedDims.push(1);
        }
        const reshapedNormalizedRewards = normalizedRewards.map(rs => rs.reshape(rs.shape.concat(expandedDims)));
        for (let g = 0; g < varGradients.length; ++g) {
          varGradients[g] = varGradients[g].mul(reshapedNormalizedRewards[g]);
        }
        return tf.mean(tf.concat(varGradients, 0), 0);
      });
    }
    return gradients;
  });
}
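// Train a 20-unit single-hidden-layer policy for 100 games of up to 200
// steps, then plot the number of steps survived in each game.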
const optimizer = tf.train.adam(0.01);
var policyNet = new PolicyNetwork(20);
var canvas = document.getElementById('canvas');
var ctx = canvas.getContext('2d');
(async () => {
  var steps = await policyNet.train(env, optimizer, 0.9, 100, 200);
  // Longer lines toward the right of the plot indicate learning progress.
  ctx.beginPath();
  ctx.strokeStyle = 'black';
  ctx.lineWidth = 1;
  ctx.moveTo(0, 300);
  for (var i = 0; i < steps.length; i++) {
    ctx.lineTo(i * 10, 300 - steps[i] * 1.5);
  }
  ctx.stroke();
})();
```
#Result
That's all.