0

More than 5 years have passed since last update.

Posted at

# 概要

jsdoでtensorflow.jsやってみた。
さらに、dqn見つけたので、やってみた。

ランダム

dqn

# サンプルコード

var Cartpole = function() {
  // Browser cart-pole environment (a port of OpenAI Gym's CartPole-v0
  // physics) that draws itself onto the page's <canvas id="canvas">.
  this.canvas = document.getElementById('canvas');
  this.ctx = this.canvas.getContext('2d');
  // Physical constants, matching the Gym implementation.
  this.gravity = 9.8;
  this.masscart = 1.0;   // mass of the cart
  this.masspole = 0.1;   // mass of the pole
  this.total_mass = (this.masspole + this.masscart);
  this.length = 0.5;     // actually HALF the pole's length
  // BUG FIX: the original read `self.length` -- in a browser `self` is the
  // window (whose `length` is the frame count, usually 0), so
  // polemass_length silently became 0 and the pole's mass dropped out of
  // the dynamics. It must read from `this`.
  this.polemass_length = (this.masspole * this.length);
  this.force_mag = 10.0; // magnitude of the push applied each step
  this.tau = 0.02;       // seconds between state updates (Euler step size)
  // Episode ends beyond +/-12 degrees of tilt or +/-2.4 units of cart x.
  this.theta_threshold_radians = 12 * 2 * Math.PI / 360;
  this.x_threshold = 2.4;
  this.state = {
    x: 0,
    x_dot: 0,
    theta: 0,
    theta_dot: 0
  };
  // Steps taken after the episode ended; null while the episode is live.
  this.steps_beyond_done = null;
};
Cartpole.prototype.reset = function() {
  // Re-seed every state component with uniform noise in [-0.2, 0.2) and
  // mark the episode as live again. Returns the fresh state.
  var sample = function() {
    return Math.random() * 0.4 - 0.2;
  };
  this.state = {
    x: sample(),
    x_dot: sample(),
    theta: sample(),
    theta_dot: sample()
  };
  this.steps_beyond_done = null;
  return this.state;
}
Cartpole.prototype.step = function(action) {
  // Advances the simulation by one Euler step of length tau.
  // action < 1 pushes the cart left; otherwise right.
  // Returns [newState, reward, done] in Gym style.
  var prev = this.state;
  var force = (action < 1) ? -this.force_mag : this.force_mag;
  var cosT = Math.cos(prev.theta);
  var sinT = Math.sin(prev.theta);
  // Equations of motion for the cart-pole (same form as Gym's CartPole-v0).
  var temp = (force + this.polemass_length * prev.theta_dot * prev.theta_dot * sinT) / this.total_mass;
  var thetaacc = (this.gravity * sinT - cosT * temp) /
      (this.length * (4.0 / 3.0 - this.masspole * cosT * cosT / this.total_mass));
  var xacc = temp - this.polemass_length * thetaacc * cosT / this.total_mass;
  // Positions integrate the OLD velocities; velocities integrate the
  // accelerations (exactly the original update order).
  this.state = {
    x: prev.x + this.tau * prev.x_dot,
    x_dot: prev.x_dot + this.tau * xacc,
    theta: prev.theta + this.tau * prev.theta_dot,
    theta_dot: prev.theta_dot + this.tau * thetaacc
  };
  // Termination is judged on the UPDATED position/angle.
  var done = Math.abs(this.state.x) > this.x_threshold ||
      Math.abs(this.state.theta) > this.theta_threshold_radians;
  var reward;
  if (!done) {
    reward = 1.0;
  } else if (this.steps_beyond_done == null) {
    // First step past the end still pays out, as in Gym.
    this.steps_beyond_done = 0;
    reward = 1.0;
  } else {
    this.steps_beyond_done += 1;
    reward = 0.0;
  }
  return [this.state, reward, done];
}
Cartpole.prototype.render = function() {
  // Draws the current state: an orange bar for the cart and a black line
  // for the pole, tilted by theta from vertical. Returns 0.
  var pxPerUnit = 100;
  var baseY = 200;
  var cartCenterX = this.state.x * pxPerUnit + this.canvas.width / 2.0;
  this.ctx.clearRect(0, 0, this.canvas.width, this.canvas.height);
  // Cart: a thick 40px-wide horizontal segment.
  this.ctx.beginPath();
  this.ctx.strokeStyle = '#ffa500';
  this.ctx.lineWidth = 8;
  this.ctx.moveTo(cartCenterX - 20, baseY);
  this.ctx.lineTo(cartCenterX + 20, baseY);
  this.ctx.stroke();
  // Pole: a 100px line from the cart center toward its tip.
  var tipX = cartCenterX + Math.sin(this.state.theta) * 100;
  var tipY = baseY - Math.cos(this.state.theta) * 100;
  this.ctx.beginPath();
  this.ctx.strokeStyle = 'black';
  this.ctx.lineWidth = 2;
  this.ctx.moveTo(cartCenterX, baseY);
  this.ctx.lineTo(tipX, tipY);
  this.ctx.stroke();
  return 0;
}
Cartpole.prototype.getStateTensor = function() {
  // Packs the 4-component observation into a 1x4 tensor for the policy net.
  var s = this.state;
  return tf.tensor2d([[s.x, s.x_dot, s.theta, s.theta_dot]]);
}
// Single shared environment instance used by the training loop below.
var env = new Cartpole();
class PolicyNetwork {
  /**
   * Policy-gradient (REINFORCE-style) network for cart-pole, following the
   * official tfjs-examples cart-pole demo.
   *
   * NOTE(review): the scraped original was missing several spans -- the
   * `this.model.add(tf.layers.dense({...}))` calls in createModel, the
   * gradient bookkeeping inside train(), and the headers of
   * getGradientsAndSaveActions() and pushGradients(). They are reconstructed
   * here from the surviving fragments; confirm against the upstream example.
   *
   * @param {number|number[]|tf.Model} hiddenLayerSizesOrModel Hidden layer
   *     size(s) for a fresh model, or an existing tf.Model to reuse.
   */
  constructor(hiddenLayerSizesOrModel) {
    if (hiddenLayerSizesOrModel instanceof tf.Model) {
      this.model = hiddenLayerSizesOrModel;
    } else {
      this.createModel(hiddenLayerSizesOrModel);
    }
  }

  /**
   * Builds the policy MLP: dense 'elu' hidden layers over the 4-dim
   * observation, plus a single-unit output producing the left-action logit.
   * @param {number|number[]} hiddenLayerSizes
   */
  createModel(hiddenLayerSizes) {
    if (!Array.isArray(hiddenLayerSizes)) {
      hiddenLayerSizes = [hiddenLayerSizes];
    }
    this.model = tf.sequential();
    hiddenLayerSizes.forEach((hiddenLayerSize, i) => {
      this.model.add(tf.layers.dense({
        units: hiddenLayerSize,
        activation: 'elu',
        // Only the first layer needs an explicit input shape.
        inputShape: i === 0 ? [4] : undefined
      }));
    });
    this.model.add(tf.layers.dense({
      units: 1
    }));
  }

  /**
   * Plays numGames episodes, collecting per-step gradients and rewards, then
   * applies reward-scaled, averaged gradients with the supplied optimizer.
   *
   * @param {Cartpole} env Environment with reset/render/step/getStateTensor.
   * @param {tf.Optimizer} optimizer Optimizer receiving the final gradients.
   * @param {number} discountRate Reward discount factor in [0, 1).
   * @param {number} numGames Episodes to play.
   * @param {number} maxStepsPerGame Step cap per episode.
   * @returns {number[]} Steps survived in each game.
   */
  train(env, optimizer, discountRate, numGames, maxStepsPerGame) {
    const allGradients = [];
    const allRewards = [];
    const gameSteps = [];
    for (let i = 0; i < numGames; ++i) {
      const gameGradients = [];
      const gameRewards = [];
      env.reset();
      for (let j = 0; j < maxStepsPerGame; ++j) {
        env.render();
        // Keep only the gradient tensors; every other intermediate created
        // here is disposed by tidy().
        const gradients = tf.tidy(() => {
          const inputTensor = env.getStateTensor();
          return this.getGradientsAndSaveActions(inputTensor).grads;
        });
        this.pushGradients(gameGradients, gradients);
        const action = this.currentActions_[0];
        const stepResult = env.step(action);
        const isDone = stepResult[2];
        if (isDone) {
          gameRewards.push(0);
          break;
        } else {
          gameRewards.push(1);
        }
      }
      gameSteps.push(gameRewards.length);
      this.pushGradients(allGradients, gameGradients);
      allRewards.push(gameRewards);
      // NOTE(review): not awaited, matching the scraped original (train is
      // synchronous here, unlike the async upstream example).
      tf.nextFrame();
    }
    tf.tidy(() => {
      const normalizedRewards =
          discountAndNormalizeRewards(allRewards, discountRate);
      optimizer.applyGradients(
          scaleAndAverageGradients(allGradients, normalizedRewards));
    });
    tf.dispose(allGradients);
    return gameSteps;
  }

  /**
   * Computes gradients of the policy loss w.r.t. the model weights for one
   * observation. Side effect: samples an action and stores it in
   * this.currentActions_.
   * @param {tf.Tensor} inputTensor 1x4 observation tensor.
   * @returns {{value: tf.Scalar, grads: Object}} Output of tf.variableGrads.
   */
  getGradientsAndSaveActions(inputTensor) {
    const f = () => tf.tidy(() => {
      const [logits, actions] = this.getLogitsAndActions(inputTensor);
      this.currentActions_ = actions.dataSync();
      // Treat the sampled action as the "correct" label (inverted, because
      // the single logit is the probability of the left action).
      const labels = tf.sub(1, tf.tensor2d(this.currentActions_, actions.shape));
      return tf.losses.sigmoidCrossEntropy(labels, logits).asScalar();
    });
    return tf.variableGrads(f);
  }

  getCurrentActions() {
    return this.currentActions_;
  }

  /**
   * @param {tf.Tensor} inputs Batch of observations.
   * @returns {[tf.Tensor, tf.Tensor]} Raw logits and an action (0 or 1)
   *     sampled from the implied left/right probabilities.
   */
  getLogitsAndActions(inputs) {
    return tf.tidy(() => {
      const logits = this.model.predict(inputs);
      const leftProb = tf.sigmoid(logits);
      const leftRightProbs = tf.concat([leftProb, tf.sub(1, leftProb)], 1);
      const actions = tf.multinomial(leftRightProbs, 1, null, true);
      return [logits, actions];
    });
  }

  getActions(inputs) {
    return this.getLogitsAndActions(inputs)[1].dataSync();
  }

  /**
   * Appends each named gradient tensor onto the per-variable list kept in
   * `record`, creating the list on first sight of a variable name.
   */
  pushGradients(record, gradients) {
    for (const key in gradients) {
      if (key in record) {
        record[key].push(gradients[key]);
      } else {
        record[key] = [gradients[key]];
      }
    }
  }
}
// IndexedDB location where the trained policy model is persisted.
const MODEL_SAVE_PATH_ = 'indexeddb://cart-pole-v1';

/**
 * A PolicyNetwork that can save, reload, and delete itself via tf.io.
 *
 * NOTE(review): the scraped original lost the `static async loadModel()`
 * header and the model-loading line; they are restored here following the
 * upstream tfjs-examples cart-pole demo -- confirm against that source.
 */
class SaveablePolicyNetwork extends PolicyNetwork {
  constructor(hiddenLayerSizesOrModel) {
    super(hiddenLayerSizesOrModel);
  }

  /** Persists the current model. @returns {Promise} save result. */
  saveModel() {
    return this.model.save(MODEL_SAVE_PATH_);
  }

  /**
   * Loads the persisted model and wraps it in a SaveablePolicyNetwork.
   * @throws {Error} If no model is stored at MODEL_SAVE_PATH_.
   */
  static async loadModel() {
    const modelsInfo = await tf.io.listModels();
    if (MODEL_SAVE_PATH_ in modelsInfo) {
      const model = await tf.loadModel(MODEL_SAVE_PATH_);
      return new SaveablePolicyNetwork(model);
    } else {
      // BUG FIX: the original template literal escaped the interpolation
      // (`\${...}`), so the message printed the placeholder verbatim.
      throw new Error(`Cannot find model at ${MODEL_SAVE_PATH_}.`);
    }
  }

  /** @returns Metadata for the stored model (undefined if none). */
  checkStoredModelStatus() {
    const modelsInfo = tf.io.listModels();
    return modelsInfo[MODEL_SAVE_PATH_];
  }

  /** Deletes the stored model. @returns {Promise} removal result. */
  removeModel() {
    return tf.io.removeModel(MODEL_SAVE_PATH_);
  }

  /**
   * Hidden-layer sizes of the wrapped model (all layers but the output).
   * @returns {number|number[]} A single size, or an array when several.
   */
  hiddenLayerSizes() {
    const sizes = [];
    for (let i = 0; i < this.model.layers.length - 1; ++i) {
      sizes.push(this.model.layers[i].units);
    }
    return sizes.length === 1 ? sizes[0] : sizes;
  }
}
/**
 * Converts per-step rewards into discounted returns: each entry becomes
 * reward[i] + discountRate * return[i + 1], computed by a backwards sweep.
 * @param {number[]} rewards Raw per-step rewards of one game.
 * @param {number} discountRate Discount factor in [0, 1).
 * @returns {tf.Tensor} 1-D tensor of discounted returns.
 */
function discountRewards(rewards, discountRate) {
  const out = tf.buffer([rewards.length]);
  let running = 0;
  for (let idx = rewards.length - 1; idx >= 0; --idx) {
    running = rewards[idx] + discountRate * running;
    out.set(running, idx);
  }
  return out.toTensor();
}
/**
 * Discounts each game's reward sequence, then normalizes all returns to zero
 * mean and unit standard deviation across every step of every game.
 * @param {number[][]} rewardSequences One reward array per game.
 * @param {number} discountRate Discount factor in [0, 1).
 * @returns {tf.Tensor[]} One normalized return tensor per game.
 */
function discountAndNormalizeRewards(rewardSequences, discountRate) {
  return tf.tidy(() => {
    const discounted = rewardSequences.map(
        (sequence) => discountRewards(sequence, discountRate));
    // Mean / std are taken over the concatenation of all games' returns.
    const flat = tf.concat(discounted);
    const mean = tf.mean(flat);
    const std = tf.sqrt(tf.mean(tf.square(flat.sub(mean))));
    return discounted.map((rs) => rs.sub(mean).div(std));
  });
}
/**
 * Scales each step's gradients by that step's normalized return and averages
 * them over all steps of all games, per model variable.
 *
 * NOTE(review): the scraped original lost this function's header and much of
 * its body; it is reconstructed around the surviving fragments
 * (`varGradients[0].rank`, `reshapedNormalizedRewards`) following the
 * upstream tfjs-examples cart-pole demo -- confirm against that source.
 *
 * @param {Object} allGradients Per-variable lists of per-game gradient lists
 *     (as built by PolicyNetwork.pushGradients).
 * @param {tf.Tensor[]} normalizedRewards One normalized-return tensor per game.
 * @returns {Object} Per-variable averaged gradient tensors for applyGradients.
 */
function scaleAndAverageGradients(allGradients, normalizedRewards) {
  return tf.tidy(() => {
    const gradients = {};
    for (const varName in allGradients) {
      gradients[varName] = tf.tidy(() => {
        // Stack each game's per-step gradients into one tensor per game.
        const varGradients = allGradients[varName].map(
            (varGameGradients) => tf.stack(varGameGradients));
        // Reshape rewards to [steps, 1, 1, ...] so they broadcast across the
        // gradient dimensions.
        const expandedDims = [];
        for (let i = 0; i < varGradients[0].rank - 1; ++i) {
          expandedDims.push(1);
        }
        const reshapedNormalizedRewards = normalizedRewards.map(
            (rs) => rs.reshape(rs.shape.concat(expandedDims)));
        for (let g = 0; g < varGradients.length; ++g) {
          // Weight every step's gradient by that step's normalized return.
          varGradients[g] = varGradients[g].mul(reshapedNormalizedRewards[g]);
        }
        // Average over all steps of all games.
        return tf.mean(tf.concat(varGradients, 0), 0);
      });
    }
    return gradients;
  });
}
// ---- Driver: train a fresh policy network, then plot steps-per-game. ----
var policyNet = new PolicyNetwork(20);
// BUG FIX: `optimizer` was used below but never defined (ReferenceError).
// An Adam optimizer with learning rate 0.05 matches the upstream
// tfjs-examples cart-pole demo's default -- confirm the rate if retuning.
var optimizer = tf.train.adam(0.05);
var steps = policyNet.train(env, optimizer, 0.9, 100, 200);
var canvas = document.getElementById('canvas');
// BUG FIX: was `this.canvas.getContext(...)` at top level; the canvas just
// fetched above is the intended target.
var ctx = canvas.getContext('2d');
// Plot the per-game survival steps as a polyline (longer games plot higher).
ctx.beginPath();
ctx.strokeStyle = 'black';
ctx.lineWidth = 1;
ctx.moveTo(0, 300);
for (var i = 0; i < steps.length; i++) {
  ctx.lineTo(i * 10, 300 - steps[i] * 1.5);
}
ctx.stroke();
``````

# 成果物

Register as a new user and use Qiita more conveniently

1. You get articles that match your needs
2. You can efficiently read back useful information
3. You can use dark theme
What you can do with signing up
0