#概要
jsdoでtensorflow.jsやってみた。
keras風、高級APIを使わないで、coreとか言われるレベルでやってみた。
九九問題、やってみた。
#写真
#学習
バッチ数: 81
input: 8
隠れ層: 1
ユニット: 60
活性化関数: tanh
output: 7
活性化関数: softmax
オプチマイザー: adam
ロス: softmaxCrossEntropy
エポック数: 9000
#サンプルコード
function tob(i, j) {
var ary = new Array;
var v = (i + 1) * (j + 1);
var b = v & 1;
if (b > 0)
{
ary.push(1);
}
else
{
ary.push(0);
}
for (var k = 0 ; k < 6; k++)
{
var b = v & (2 << k);
if (b > 0)
{
ary.push(1);
}
else
{
ary.push(0);
}
}
return ary;
}
const buffer2 = tf.buffer([81, 7]);
for (var i = 0; i < 9; i++)
{
for (var j = 0; j < 9; j++)
{
var l = i * 9 + j;
var x = tob(i, j);
for (var k = 0; k < 7; k++)
{
buffer2.set(x[k], l, k);
}
}
}
const yt = buffer2.toTensor();
function toa(i, j) {
var ary = new Array;
var v = i + 1;
var b = v & 1;
if (b > 0)
{
ary.push(1);
}
else
{
ary.push(0);
}
for (var k = 0 ; k < 3; k++)
{
var b = v & (2 << k);
if (b > 0)
{
ary.push(1);
}
else
{
ary.push(0);
}
}
var v = j + 1;
var b = v & 1;
if (b > 0)
{
ary.push(1);
}
else
{
ary.push(0);
}
for (var k = 0 ; k < 3; k++)
{
var b = v & (2 << k);
if (b > 0)
{
ary.push(1);
}
else
{
ary.push(0);
}
}
return ary;
}
const buffer = tf.buffer([81, 8]);
for (var i = 0; i < 9; i++)
{
for (var j = 0; j < 9; j++)
{
var l = i * 9 + j;
var x = toa(i, j);
for (var k = 0; k < 8; k++)
{
buffer.set(x[k], l, k);
}
}
}
const xt = buffer.toTensor();
var num = 60;
const w1 = tf.variable(tf.randomNormal([8, num]));
const b1 = tf.variable(tf.randomNormal([num]));
const w3 = tf.variable(tf.randomNormal([num, 7]));
const b3 = tf.variable(tf.randomNormal([7]));
function func(x) {
const h1 = tf.tanh(x.matMul(w1).add(b1));
return tf.softmax(h1.matMul(w3).add(b3));
}
function loss(pred, ypred) {
return tf.losses.softmaxCrossEntropy(pred, ypred).mean();
}
const optimizer = tf.train.adam(0.01);
var cc;
for (let i = 0; i < 9001; i++)
{
const cost = optimizer.minimize(() => loss(func(xt), yt), true);
cc = cost;
}
//document.write(func(xt));
var pre = func(xt);
var p = pre.dataSync();
var col;
var row;
document.write('<table>');
var lim = 0.15;
for (row = 0; row < 10; row++)
{
document.write('<tr>');
for (col = 0; col < 10; col++)
{
if (col === 0 && row === 0)
{
document.write('<th> <\/th>');
}
else if (col === 0 && row !== 0)
{
document.write('<th>' + row + '<\/th>');
}
else if (row === 0)
{
document.write('<th>' + col + '<\/th>');
}
else
{
var i = (row - 1) * 9 + (col - 1);
var v = 0;
if (p[i * 7 + 0] > lim) v += 1;
if (p[i * 7 + 1] > lim) v += 2;
if (p[i * 7 + 2] > lim) v += 4;
if (p[i * 7 + 3] > lim) v += 8;
if (p[i * 7 + 4] > lim) v += 16;
if (p[i * 7 + 5] > lim) v += 32;
if (p[i * 7 + 6] > lim) v += 64;
document.write('<td>' + v + '<\/td>');
}
}
document.write('<\/tr>');
}
document.write('<\/table>');
#成果物
以上。