LoginSignup
0
0

More than 5 years have passed since last update.

概要

Lazurite でニューラルネットワーク(NN)を実装してみた。
XOR 問題を学習させ、分類させてみた。

結果

ok
0 0.31392 
10000 0.31176 
20000 0.02209 
30000 0.01108 
40000 0.00301 
50000 0.00362 
60000 0.00302 
70000 0.00154 
0,0 = 0.033 
0,1 = 0.951 
1,0 = 0.961 
1,1 = 0.047 

サンプルコード

#include "nn1_ide.h"        // Additional Header

// Shared network state used by train().

// Input vector: train() sets x[0] to the constant 1 and x[1], x[2] to the two XOR operands.
int x[3];
// Target values; train() only writes and reads t[0] (the XOR of x[1] and x[2]).
int t[3];
// Hidden-layer activations: g[i] = sigmoid(sum over j of w[j][i] * x[j]).
double g[3];
// Output-layer activations: z[i] = sigmoid(sum over j of h[j][i] * g[j]).
double z[3];
// Hidden-to-output weights, indexed h[hidden index][output index].
double h[3][3];
// Input-to-hidden weights, indexed w[input index][hidden index].
double w[3][3];
// Logistic function 1 / (1 + e^-a). The argument is parenthesised and expanded once.
// NOTE(review): relies on exp() being declared, presumably via nn1_ide.h - confirm.
#define sigmoid(a)        (1.0 / (1.0 + exp(-(a))))
// Runs one forward pass for the inputs currently stored in x[].
// Writes the hidden activations to g[0..2] and the single network output to z[0].
static void forward_pass()
{
    int src;
    int dst;
    double sum;
    for (dst = 0; dst <= 2; dst++)
    {
        sum = 0;
        for (src = 0; src <= 2; src++)
        {
            sum += w[src][dst] * x[src];
        }
        g[dst] = sigmoid(sum);
    }
    // Only output 0 is trained and reported, so only z[0] is computed.
    sum = 0;
    for (src = 0; src <= 2; src++)
    {
        sum += h[src][0] * g[src];
    }
    z[0] = sigmoid(sum);
}

// Evaluates the network for the operand pair (a, b) and prints z[0] with
// three decimals followed by the line terminator used throughout the sketch.
static void print_prediction(int a, int b)
{
    // Set the constant input explicitly instead of relying on the value
    // left behind by the last training iteration.
    x[0] = 1;
    x[1] = a;
    x[2] = b;
    forward_pass();
    Serial.print_double(z[0], 3);
    Serial.println(" ");
}

// Trains the network on randomly drawn XOR samples with per-sample gradient
// descent, printing the squared error every 10000 iterations, then prints the
// network output for each of the four operand pairs.
//
// Reads and writes the file-level arrays x, t, g, z, h and w.
void train()
{
    double alpha = 0.06;            // learning rate
    int i;
    int j;
    long count;
    double error;
    double out_delta;
    double hidden_delta[3];

    // Initialise both weight matrices with values in [-1.0, 1.0).
    for (j = 0; j <= 2; j++)
    {
        for (i = 0; i <= 2; i++)
        {
            w[j][i] = (double) ((rand() % 2000) - 1000) / 1000;
        }
    }
    for (j = 0; j <= 2; j++)
    {
        for (i = 0; i <= 2; i++)
        {
            h[j][i] = (double) ((rand() % 2000) - 1000) / 1000;
        }
    }

    for (count = 0; count < 70001; count++)
    {
        // Draw one random training sample; x[0] stays at the constant 1.
        x[0] = 1;
        x[1] = rand() % 2;
        x[2] = rand() % 2;
        t[0] = (x[1] ^ x[2]);

        forward_pass();

        // Compute every delta from the weights used in the forward pass
        // BEFORE touching any weight, so the hidden-layer gradient is not
        // contaminated by the output-layer update of the same step.
        out_delta = (t[0] - z[0]) * z[0] * (1.0 - z[0]);
        for (j = 0; j <= 2; j++)
        {
            hidden_delta[j] = out_delta * h[j][0] * g[j] * (1.0 - g[j]);
        }

        // Apply the updates: hidden-to-output weights first, then input-to-hidden.
        for (i = 0; i <= 2; i++)
        {
            h[i][0] += alpha * g[i] * out_delta;
        }
        for (j = 0; j <= 2; j++)
        {
            for (i = 0; i <= 2; i++)
            {
                w[i][j] += alpha * x[i] * hidden_delta[j];
            }
        }

        // Squared error of this sample, measured on the pre-update output.
        error = (t[0] - z[0]) * (t[0] - z[0]);

        if ((count % 10000) == 0)
        {
            Serial.print_long(count, DEC);
            Serial.print(" ");
            Serial.print_double(error, 5);
            Serial.println(" ");
        }
    }

    // Report the trained network's answer for all four XOR inputs.
    Serial.print("0,0 = ");
    print_prediction(0, 0);
    Serial.print("0,1 = ");
    print_prediction(0, 1);
    Serial.print("1,0 = ");
    print_prediction(1, 0);
    Serial.print("1,1 = ");
    print_prediction(1, 1);
}

// Sketch entry point: starts Serial, prints "ok", then runs the whole
// training and evaluation once via train().
void setup()
{
    // 115200 is presumably the baud rate - confirm against the Serial API.
    Serial.begin(115200);
    Serial.println("ok");
    train();
}
// Intentionally empty: all work is done once from setup().
// NOTE(review): presumably invoked repeatedly by the runtime after setup() - confirm.
void loop()
{
}


以上。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0