
tensorflow.js on jsdo, part 17


Overview

I tried tensorflow.js on jsdo.
I also found a dqn sample (cart-pole), so I gave that a try as well.

Screenshots

Random

dqn
[image]

Sample code
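
// Requires TensorFlow.js. This code uses the 0.x-era API (tf.Model,
// tf.loadModel), loaded for example via a script tag such as
// <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.13.0"></script>,
// and assumes the page has a <canvas id="canvas"> element.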

var Cartpole = function() {
    this.canvas = document.getElementById('canvas');
    this.ctx = this.canvas.getContext('2d');
    this.gravity = 9.8;
    this.masscart = 1.0;
    this.masspole = 0.1;
    this.total_mass = (this.masspole + this.masscart);
    this.length = 0.5;
    this.polemass_length = (this.masspole * this.length);
    this.force_mag = 10.0;
    this.tau = 0.02;
    this.theta_threshold_radians = 12 * 2 * Math.PI / 360;
    this.x_threshold = 2.4;
    this.state = {
       x: 0,
       x_dot: 0,
       theta: 0,
       theta_dot: 0
    };
    this.steps_beyond_done = null;
};
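// reset(): start a new episode with all four state variables drawn
// uniformly from [-0.2, 0.2).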
Cartpole.prototype.reset = function() {
    this.state = {
       x: Math.random() * 0.4 - 0.2,
       x_dot: Math.random() * 0.4 - 0.2,
       theta: Math.random() * 0.4 - 0.2,
       theta_dot: Math.random() * 0.4 - 0.2
    };
    this.steps_beyond_done = null;
    return this.state;
}
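// step(): advance the simulation by one tick of tau seconds using Euler
// integration of the cart-pole dynamics. action < 1 pushes the cart left,
// otherwise right. The episode is done when the cart leaves +-x_threshold
// or the pole tips past +-12 degrees.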
Cartpole.prototype.step = function(action) {
    var state = this.state;
    var x = state.x; 
    var x_dot = state.x_dot;
    var theta = state.theta;
    var theta_dot = state.theta_dot;
    var force = this.force_mag;
    if (action < 1) force = -this.force_mag;
    var costheta = Math.cos(theta);
    var sintheta = Math.sin(theta);
    var temp = (force + this.polemass_length * theta_dot * theta_dot * sintheta) / this.total_mass;
    var thetaacc = (this.gravity * sintheta - costheta * temp) / (this.length * (4.0 / 3.0 - this.masspole * costheta * costheta / this.total_mass));
    var xacc = temp - this.polemass_length * thetaacc * costheta / this.total_mass;
    x = x + this.tau * x_dot;
    x_dot = x_dot + this.tau * xacc;
    theta = theta + this.tau * theta_dot;
    theta_dot = theta_dot + this.tau * thetaacc;
    this.state = {
        x: x, 
        x_dot: x_dot, 
        theta: theta, 
        theta_dot: theta_dot
    };
    var reward;
    var done = x < -this.x_threshold || x > this.x_threshold || theta < -this.theta_threshold_radians || theta > this.theta_threshold_radians;
    if (!done) 
    {
        reward = 1.0;
    }
    else if (this.steps_beyond_done == null)
    {
        this.steps_beyond_done = 0;
        reward = 1.0;
    }
    else
    {
        if (this.steps_beyond_done == 0) 
        {
            console.warn("step() called after the episode is done; call reset() first.");
        }
        this.steps_beyond_done += 1;
        reward = 0.0;
    }
    return [this.state, reward, done];
}
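// render(): draw the cart (orange bar) and the pole (black line) on the canvas.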
Cartpole.prototype.render = function() {
    var scale = 100;
    var x = this.state.x;
    var theta = this.state.theta;
    var cartx = x * scale + this.canvas.width / 2.0;
    this.ctx.clearRect(0, 0, this.canvas.width, this.canvas.height);
    this.ctx.beginPath();
    this.ctx.strokeStyle = '#ffa500';
    this.ctx.lineWidth = 8;
    this.ctx.moveTo(cartx - 20, 200);
    this.ctx.lineTo(cartx + 20, 200);
    this.ctx.stroke();
    this.ctx.beginPath();
    this.ctx.strokeStyle = 'black';
    this.ctx.lineWidth = 2;
    this.ctx.moveTo(cartx, 200);
    this.ctx.lineTo(cartx + Math.sin(theta) * 100, 200 - Math.cos(theta) * 100);
    this.ctx.stroke();
    return 0;
}
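// getStateTensor(): pack [x, x_dot, theta, theta_dot] into a 1x4 tensor as
// input for the network.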
Cartpole.prototype.getStateTensor = function() {
    return tf.tensor2d([[this.state.x, this.state.x_dot, this.state.theta, this.state.theta_dot]]);
}
var env = new Cartpole();
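// PolicyNetwork: a policy-gradient (REINFORCE-style) learner. The dense
// network outputs a single logit, interpreted as the probability of pushing
// the cart to the left.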
class PolicyNetwork {
    constructor(hiddenLayerSizesOrModel) {
        if (hiddenLayerSizesOrModel instanceof tf.Model) 
        {
            this.model = hiddenLayerSizesOrModel;
        }
        else
        {
            this.createModel(hiddenLayerSizesOrModel);
        }
    }
    createModel(hiddenLayerSizes) {
        if (!Array.isArray(hiddenLayerSizes)) 
        {
            hiddenLayerSizes = [hiddenLayerSizes];
        }
        this.model = tf.sequential();
        hiddenLayerSizes.forEach((hiddenLayerSize, i) => {
            this.model.add(tf.layers.dense({
                units: hiddenLayerSize,
                activation: 'elu',
                inputShape: i === 0 ? [4] : undefined
            }));
        });
        this.model.add(tf.layers.dense({
            units: 1
        }));
    }
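    // train(): play numGames episodes of at most maxStepsPerGame steps each.
    // For every step, record the gradients of the loss w.r.t. the sampled
    // action; afterwards, scale them by the discounted, normalized rewards
    // and apply a single optimizer update. Returns steps survived per game.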
    async train(env, optimizer, discountRate, numGames, maxStepsPerGame) {
        const allGradients = [];
        const allRewards = [];
        const gameSteps = [];
        for (let i = 0; i < numGames; ++i) 
        {
            const gameRewards = [];
            const gameGradients = [];
            var observation = env.reset();
            for (let j = 0; j < maxStepsPerGame; ++j)
            {

                env.render();         
                const gradients = tf.tidy(() => {
                    const inputTensor = env.getStateTensor();
                    return this.getGradientsAndSaveActions(inputTensor).grads;
                });
                this.pushGradients(gameGradients, gradients);
                const action = this.currentActions_[0];
                var res = env.step(action);
                observation = res[0];
                var reward = res[1];
                var isDone = res[2];            
                if (isDone) 
                {
                    gameRewards.push(0);
                    break;
                }
                else 
                {
                    gameRewards.push(1);
                }
            }
            gameSteps.push(gameRewards.length);
            this.pushGradients(allGradients, gameGradients);
            allRewards.push(gameRewards);
            await tf.nextFrame();   // yield to the browser between games so render() can paint
        }
        tf.tidy(() => {
            const normalizedRewards = discountAndNormalizeRewards(allRewards, discountRate);
            optimizer.applyGradients(scaleAndAverageGradients(allGradients, normalizedRewards));
        });
        tf.dispose(allGradients);
        return gameSteps;
    }
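    // Run the model on inputTensor, sample an action, and return the
    // gradients of sigmoidCrossEntropy(1 - action, logits) w.r.t. the
    // trainable weights, i.e. the direction that makes the sampled action
    // more likely (the label is 1 - action because the logit encodes the
    // probability of action 0, pushing left).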
    getGradientsAndSaveActions(inputTensor) {
        const f = () => tf.tidy(() => {
            const [logits, actions] = this.getLogitsAndActions(inputTensor);
            this.currentActions_ = actions.dataSync();
            const labels = tf.sub(1, tf.tensor2d(this.currentActions_, actions.shape));
            return tf.losses.sigmoidCrossEntropy(labels, logits).asScalar();
        });
        return tf.variableGrads(f);
    }
    getCurrentActions() {
        return this.currentActions_;
    }
    getLogitsAndActions(inputs) {
        return tf.tidy(() => {
            const logits = this.model.predict(inputs);
            const leftProb = tf.sigmoid(logits);
            const leftRightProbs = tf.concat([leftProb, tf.sub(1, leftProb)], 1);
            const actions = tf.multinomial(leftRightProbs, 1, null, true);
            return [logits, actions];
        });
    }
    getActions(inputs) {
        return this.getLogitsAndActions(inputs)[1].dataSync();
    }
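    // pushGradients(): append each per-variable gradient to a running
    // record keyed by variable name.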
    pushGradients(record, gradients) {
        for (const key in gradients) 
        {
            if (key in record)
            {
                record[key].push(gradients[key]);
            }
            else
            {
                record[key] = [gradients[key]];
            }
        }
    }
}
const MODEL_SAVE_PATH_ = 'indexeddb://cart-pole-v1';
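// SaveablePolicyNetwork: adds persistence of the model to the browser's
// IndexedDB under 'indexeddb://cart-pole-v1'.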
class SaveablePolicyNetwork extends PolicyNetwork {
    constructor(hiddenLayerSizesOrModel) {
        super(hiddenLayerSizesOrModel);
    }
    saveModel() {
        return this.model.save(MODEL_SAVE_PATH_);
    }
    static async loadModel() {
        const modelsInfo = await tf.io.listModels();
        if (MODEL_SAVE_PATH_ in modelsInfo)
        {
            const model = await tf.loadModel(MODEL_SAVE_PATH_);
            return new SaveablePolicyNetwork(model);
        }
        else
        {
            throw new Error(`Cannot find model at ${MODEL_SAVE_PATH_}.`);
        }
    }
    async checkStoredModelStatus() {
        const modelsInfo = await tf.io.listModels();
        return modelsInfo[MODEL_SAVE_PATH_];
    }
    removeModel() {
        return tf.io.removeModel(MODEL_SAVE_PATH_);
    }
    hiddenLayerSizes() {
        const sizes = [];
        for (let i = 0; i < this.model.layers.length - 1; ++i)
        {
            sizes.push(this.model.layers[i].units);
        }
        return sizes.length === 1 ? sizes[0] : sizes;
    }
}
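// discountRewards(): compute the discounted return for every step with a
// backward pass: R_i = r_i + discountRate * R_{i+1}.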
function discountRewards(rewards, discountRate) {
    const discountedBuffer = tf.buffer([rewards.length]);
    let prev = 0;
    for (let i = rewards.length - 1; i >= 0; --i)
    {
        const current = discountRate * prev + rewards[i];
        discountedBuffer.set(current, i);
        prev = current;
    }
    return discountedBuffer.toTensor();
}
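// Discount each game's rewards, then normalize across all games to zero
// mean and unit standard deviation, so better-than-average steps receive
// positive weights and worse ones negative.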
function discountAndNormalizeRewards(rewardSequences, discountRate) {
    return tf.tidy(() => {
        const discounted = [];
        for (const sequence of rewardSequences) 
        {
            discounted.push(discountRewards(sequence, discountRate));
        }
        const concatenated = tf.concat(discounted);
        const mean = tf.mean(concatenated);
        const std = tf.sqrt(tf.mean(tf.square(concatenated.sub(mean))));
        const normalized = discounted.map(rs => rs.sub(mean).div(std));
        return normalized;
    });
}
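// For each trainable variable, weight every step's gradient by that step's
// normalized reward, then average over all steps of all games.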
function scaleAndAverageGradients(allGradients, normalizedRewards) {
    return tf.tidy(() => {
        const gradients = {};
        for (const varName in allGradients)
        {
            gradients[varName] = tf.tidy(() => {
                const varGradients = allGradients[varName].map(varGameGradients => tf.stack(varGameGradients));
                const expandedDims = [];
                for (let i = 0; i < varGradients[0].rank - 1; ++i)
                {
                    expandedDims.push(1);
                }
                const reshapedNormalizedRewards = normalizedRewards.map(rs => rs.reshape(rs.shape.concat(expandedDims)));
                for (let g = 0; g < varGradients.length; ++g)
                {
                    varGradients[g] = varGradients[g].mul(reshapedNormalizedRewards[g]);
                }
                return tf.mean(tf.concat(varGradients, 0), 0);
            });
        }
        return gradients;
    });
}
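// Driver: train for 100 games of up to 200 steps each, then plot the number
// of steps survived in each game as a line chart on the same canvas.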
const optimizer = tf.train.adam(0.01);
var policyNet = new PolicyNetwork(20);
policyNet.train(env, optimizer, 0.9, 100, 200).then(function(steps) {
    var canvas = document.getElementById('canvas');
    var ctx = canvas.getContext('2d');
    ctx.beginPath();
    ctx.strokeStyle = 'black';
    ctx.lineWidth = 1;
    ctx.moveTo(0, 300);
    for (var i = 0; i < steps.length; i++)
    {
        ctx.lineTo(i * 10, 300 - steps[i] * 1.5);
    }
    ctx.stroke();
});
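
SaveablePolicyNetwork is defined above but never used by the driver. As a
minimal sketch of persisting the trained weights (assuming the tfjs 0.x API
and the classes above; saveableNet is just an illustrative name):

var saveableNet = new SaveablePolicyNetwork(20);
saveableNet.train(env, optimizer, 0.9, 100, 200)
    .then(function() { return saveableNet.saveModel(); })            // write to 'indexeddb://cart-pole-v1'
    .then(function() { return SaveablePolicyNetwork.loadModel(); })  // reload in a later session
    .then(function(restored) {
        console.log('hidden layer sizes:', restored.hiddenLayerSizes());
    });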



Result

That's all.
