概要
KelpNetの作法を調べてみた。
九九（掛け算の表）を学習させてみた。
結果
サンプルコード
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using KelpNet;
namespace ConsoleApp5
{
    /// <summary>
    /// Console sample that fits a KelpNet <c>FunctionStack</c> to the
    /// multiplication table: both factors (0-9) are fed in as 4-bit binary
    /// vectors and the product is learned as a 7-bit binary vector.
    /// </summary>
    class Program
    {
        // Number of passes over the whole training set.
        const int EPOCH = 2001;

        // Each factor ranges over 0 .. SIDE-1.
        const int SIDE = 10;

        // Total number of (i, j) training pairs.
        const int N = SIDE * SIDE;

        // Bits used to encode one factor (values 0-9 fit in 4 bits).
        const int FACTOR_BITS = 4;

        // Bits used to encode a product (largest is 9 * 9 = 81, which fits in 7 bits).
        const int PRODUCT_BITS = 7;

        /// <summary>
        /// Encodes <paramref name="value"/> as <paramref name="width"/> binary
        /// digits, most significant bit first, each stored as 0 or 1.
        /// </summary>
        /// <param name="value">Non-negative integer to encode.</param>
        /// <param name="width">Number of bits to emit.</param>
        /// <returns>An array of length <paramref name="width"/>.</returns>
        static Real[] ToBits(int value, int width)
        {
            Real[] bits = new Real[width];
            for (int pos = 0; pos < width; pos++)
            {
                int shift = width - 1 - pos;
                bits[pos] = (Real)((value >> shift) % 2);
            }
            return bits;
        }

        /// <summary>
        /// Builds the training set, trains the network, then prints the
        /// predicted 9x9 multiplication table and waits for a key press.
        /// </summary>
        /// <param name="args">Command-line arguments (unused).</param>
        static void Main(string[] args)
        {
            Real[][] trainData = new Real[N][];
            Real[][] trainLabel = new Real[N][];

            // Sample i * SIDE + j holds the pair (i, j): the input is the bits of i
            // followed by the bits of j, the label is the bits of i * j.
            for (int i = 0; i < SIDE; i++)
            {
                for (int j = 0; j < SIDE; j++)
                {
                    int index = i * SIDE + j;
                    trainData[index] = ToBits(i, FACTOR_BITS).Concat(ToBits(j, FACTOR_BITS)).ToArray();
                    trainLabel[index] = ToBits(i * j, PRODUCT_BITS);
                }
            }

            // Linear(8, 40) -> Tanh -> Linear(40, 7); the outer sizes follow the encoding widths.
            FunctionStack nn = new FunctionStack(
                new Linear(2 * FACTOR_BITS, 40, name: "l1 Linear"),
                new TanhActivation(name: "act"),
                new Linear(40, PRODUCT_BITS, name: "l2 Linear"));
            nn.SetOptimizer(new MomentumSGD());

            Console.WriteLine("Train Start...");
            for (int epoch = 0; epoch < EPOCH; epoch++)
            {
                Real loss = 0;
                for (int sample = 0; sample < N; sample++)
                {
                    loss += Trainer.Train(nn, trainData[sample], trainLabel[sample], new MeanSquaredError());
                }

                // Report the mean loss of the epoch every 100 epochs.
                if (epoch % 100 == 0)
                {
                    Console.WriteLine("loss:" + loss / N);
                }
            }

            Console.WriteLine("Test Start...");

            // The corner cell must be as wide as a row label ("{0, 2} " = 3 chars),
            // otherwise the column headers are shifted against the values.
            Console.Write("   ");
            for (int col = 1; col < SIDE; col++)
            {
                Console.Write("{0, 2} ", col);
            }
            Console.WriteLine("");

            for (int row = 1; row < SIDE; row++)
            {
                Console.Write("{0, 2} ", row);
                for (int col = 1; col < SIDE; col++)
                {
                    NdArray result = nn.Predict(trainData[row * SIDE + col])[0];

                    // Threshold each output at 0.5 and rebuild the integer, MSB first.
                    int product = 0;
                    for (int bit = 0; bit < PRODUCT_BITS; bit++)
                    {
                        product *= 2;
                        if (result.Data[bit] > 0.5)
                        {
                            product += 1;
                        }
                    }
                    Console.Write("{0, 2} ", product);
                }
                Console.WriteLine("");
            }

            Console.WriteLine("Press any key to exit.");
            Console.ReadKey();
        }
    }
}
以上。