2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

OpenGLを用いての単純SpikingNeuron

Last updated at Posted at 2018-05-30

SpikingNeuralNetwork(SNN)について調べていたんだけど、全然文献が見つからない…。

英語で書かれている論文は少しあるんだけど全然読めない(英語超苦手)。

いろんなところの解説を少しずつ読んでいって、なんとか簡単なプログラムを書くことができたので、ここでまとめておきます。

SpikingNeuralNetwork(SNN)とは

スパイキングニューラルネットワークの略で、ニューラルネットワークの一種です。

ニューラルネットワークとは人間の脳神経のニューロンを数理モデル化した単純パーセプトロンを複数組み合わせたものです。

多層パーセプトロン.jpg

こんな感じ。

ニューラルネットワークは様々な種類が考案されていて、それぞれ得意分野が異なります。

例えば、出力を次の入力として扱うRecurrentNeuralNetwork(RNN)は自然言語処理の分野で大活躍しているし、通常の層と違い、畳み込み層とプーリング層の特殊な層を含む構造であるConvolutionalNeuralNetwork(CNN)は画像処理でよく使われます。

このように、様々な形のニューラルネットワークが存在しますが、その中でも、脳の働きを最も忠実に再現したものがSpikingNeuralNetwork(SNN)です。

ニューロンは電気信号を伝えるのですが、その電気信号を活動電位と呼びます。
ニューロンの活動電位はずっと流れているわけでなく、一瞬だけ電位が急激に上がり、その後すぐに下がります。

スクリーンショット 2018-05-30 14.47.08.png

こんな感じ。

この、急激に電位が上がることを「スパイク(Spike)」といいます。

このスパイクを再現しているニューラルネットワークをSpikingNeuralNetwork(SNN)と呼ぶのです。

ユニットに入力値から電位を貯めていき、電位が閾値以上になったら出力をします。

このようなユニットを複数組み合わせてSNNを構成します。

SNN.png

コード

勉強した成果の確認のために簡単なSNNのコードを書いてみました。

C言語で書いており、結果描写にはOpenGLを用いています。

spiking.c

#include <stdio.h>
#include <OpenGL/OpenGL.h>
#include <GLUT/GLUT.h>
#include <math.h>
#include <string.h>
#include <time.h>
#include <stdlib.h>

#define inunit 3    //入力次元数

const int xmin=0, xmax=1200,ymin=0, ymax=800;
const int gramin = 50, gramax = 1150;
int cflag = 0;
int window;

//min〜numまでの乱数を生成
int ran(int num,int min){
    return rand()%num+min;
}

void makecircle(short cx, short cy, float r, int c)     //円の中心x座標、y座標、半径、色の受け取り c=0:緑 c=1:赤 c=2:青 c=3: 黄色 その他:桃
{
    int i,n=100;        //n=分割数、r=半径
    double rate;
    float cx2,cy2;
    
    if(c == 0)
        glColor3f(0.0f, 1.0f, 0.0f);
    else if(c == 1)
        glColor3f(1.0f, 0.0f, 0.0f);
    else if(c == 2)
        glColor3f(0.0f, 0.0f, 1.0f);
    else if(c == 3)
        glColor3f(1.0f, 0.5f, 0.0f);
    else
        glColor3f(0.0f, 0.0f, 0.0f);
    
    glBegin(GL_POLYGON);                           //ポリゴンの頂点記述開始
    for (i = 0; i < n; i++) {
        rate = (double)i / n;
        cx2 = (float)cx + r * cos(2.0 * M_PI * rate);       //円の描写
        cy2 = (float)cy + r * sin(2.0 * M_PI * rate);       // cx、cyは円の中心座標
        glVertex3s(cx2, cy2, 0);            //ポリゴン頂点の座標
    }
    glEnd();                    //頂点記述の終了
}


void makeLine(short x1, short y1, short x2, short y2,int c){
    
    if(c == 0)
        glColor3f(0.0f, 1.0f, 0.0f);
    else if(c == 1)
        glColor3f(1.0f, 0.0f, 0.0f);
    else if(c == 2)
        glColor3f(0.0f, 0.0f, 1.0f);
    else if(c == 3)
        glColor3f(1.0f, 0.5f, 0.0f);
    else
        glColor3f(0.0f, 0.0f, 0.0f);
    
    glBegin(GL_LINES);
    glVertex3s(x1, y1, 0);
    glVertex3s(x2, y2, 0);
    glEnd();
    
}


//グラフの描画
void initGraph(){
    
    int in1len = 100;
    int in2len = 100;
    int in3len = 100;
    int outlen = 200;
    int span = 20;
    
    //背景の塗りつぶし
    glClearColor(1.0, 1.0, 1.0, 1.0);
    glClear(GL_COLOR_BUFFER_BIT);
    
    //入力1グラフの描写
    makeLine(xmin+50, span, xmin+50, span+in1len, 4);
    makeLine(xmin+50, span+in1len, xmax-50, span+in1len, 4);
    
    //入力2グラフの描写
    makeLine(xmin+50, span*2+in1len, xmin+50, span*2+in1len+in2len, 4);
    makeLine(xmin+50, span*2+in1len+in2len, xmax-50, span*2+in1len+in2len, 4);
    
    //入力3グラフの描写
    makeLine(xmin+50, span*3+in1len+in2len, xmin+50, span*3+in1len+in2len+in3len, 4);
    makeLine(xmin+50, span*3+in1len+in2len+in3len, xmax-50, span*3+in1len+in2len+in3len, 4);
    
    //出力グラフの描写
    makeLine(xmin+50, span*4+in1len+in2len+in3len, xmin+50, span*4+in1len+in2len+in3len+outlen, 4);
    makeLine(xmin+50, span*4+in1len+in2len+in3len+outlen, xmax-50, span*4+in1len+in2len+in3len+outlen, 4);
    
    makeLine(xmin+50, span*5+in1len+in2len+in3len+outlen, xmin+50, span*5+in1len+in2len+in3len+outlen+100, 4);
    makeLine(xmin+50, span*5+in1len+in2len+in3len+outlen+100, xmax-50, span*5+in1len+in2len+in3len+outlen+100, 4);
    
}


void spiking(){
    
    int inputs[inunit];       //入力
    int output;               //出力
    static int weight[inunit] = { 4, -2, 3};    //ニューロンの重み
    static int t = 0;         //グラフがリセットされてからの経過時間
    static int solidt = 0;       //開始からの時間
    static int over = 0;         //
    int leak = 1;                //減衰率
    static int v = 0;            //電圧
    static int prev = 0;         //1ステップ前の電圧
    int thresh = 8;              //閾値
    int spike = 4;               //スパイクしたときの上昇値
    int spikeFlag = 0;           //スパイクしたかどうか
    static int tNext = 0;        //次にアクティブになるまでの時間
    int latency = 2;             //スパイク後の
    
    //入力値の決定
    for(int i=0; i<inunit; i++){
        inputs[i] = ran(2,0);
    }
    

    //最初にパラメータを表示
    if(t == 0){
        //重みの表示
        printf("The weights are:  ");
        for(int i=0; i < inunit; i++){
            printf("%d  ",weight[i]);
        }
        puts("\n");
    
        //パラメータの表示
        printf("The leak potential is : %d\n",leak);
        printf("The threshold is      : %d\n",thresh);
        printf("The spike is          : %d\n",spike);
        printf("The latency time is   : %d\n", latency);
        puts("\n");
    
        //グラフの描画
        initGraph();
    }
    
    //グラフの描画範囲を超えたらリセット
    if(((t+1)*7+50) >= (xmax-50)){
        t = 0;
        initGraph();
    }
   
        
    if(solidt != tNext){
        output = 0;
        prev = v;
    }else{
        int sum = 0;
        for(int j=0; j<inunit; j++)
            sum += inputs[j] * weight[j];
        prev = v;
        v = v + sum;
        v = v - leak;
        if(v<0)
            v = 0;
        if(v>=thresh){
            spikeFlag = 1;
            output = 1;
            tNext = solidt + 1 + latency;
        }else{
            output = 0;
            tNext = solidt + 1;
        }
    }

    
    //閾値の描画
    makeLine(0, -thresh * 10 + 580, xmax, -thresh * 10 + 580,0);

    //入力の描画
    if(inputs[0] == 1){
        makeLine((t+1)*7-2+50, 120,(t+1)*7-2+50,40,4);
    }
    if(inputs[1] == 1){
        makeLine((t+1)*7-2+50, 240,(t+1)*7-2+50,160,4);
    }
    if(inputs[2] == 1){
        makeLine((t+1)*7-2+50, 360,(t+1)*7-2+50,280,4);
    }
    
    //出力の描写
    if(output == 1){
        makeLine((t+1)*7-2+50, 700, (t+1)*7-2+50,620,4);
    }
    
    //SNNグラフの描画
    if(spikeFlag == 1){
        makeLine(t*7+50, -prev*10+580, (t+1)*7-2+50, -v*10+580,2);
        makecircle((t+1)*7-2+50, -v*10+580, 2, 1);
        prev = v;
        v += spike;
        makeLine((t+1)*7-2+50, -prev*10+580, (t+1)*7+50, -v*10+580,1);
        makecircle((t+1)*7+50, -v*10+580, 2, 1);
        prev = v;
        v = 0;
        makeLine((t+1)*7+50, -prev*10+580, (t+1)*7+50, -v*10+580,2);
        makecircle((t+1)*7+50, -v*10+580,2,1);
        spikeFlag = 0;
    }else{
        makeLine(t*7+50, -prev*10+580, (t+1)*7+50, -v*10+580,2);
        makecircle((t+1)*7+50, -v*10+580,2,1);
    }
    
    solidt++;
    t++;
    
}


void DrawGLScene_2D()
{
    glLoadIdentity();
    
    spiking();
    
    glFlush();
}

void ReSizeGLScene_2D(int Width, int Height)
{
    glViewport(0,0,Width,Height);
    glMatrixMode(GL_PROJECTION);
    glLoadIdentity();
    glOrtho (0, Width, Height, 0, -1.0f, 1.0f);
    glMatrixMode(GL_MODELVIEW);
}

void InitGL_2D(int Width, int Height)
{
    glClearColor(1.0f, 1.0f, 1.0f, 1.0f);
    glClearDepth(1.0);
    glDepthFunc(GL_LESS);
    glDisable(GL_DEPTH_TEST);
    glEnable(GL_BLEND);
    glBlendFunc (GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA);
    glShadeModel(GL_SMOOTH);
    ReSizeGLScene_2D(Width, Height);
}

void keyPressed(unsigned char key, int x1, int y1)
{
    if (key == 'c'){
        cflag = 1;
    }
}

#pragma mark Timer
int tcounter=0;
GLfloat top = -0.9;
void cal_env(int vl)
{
    static GLboolean isUp = GL_TRUE;
    
    if (top > 0.9F) isUp = GL_FALSE;
    else if (top <= -0.9F) isUp = GL_TRUE;
    top += (isUp == GL_TRUE ? 0.01 : -0.01);
    
    DrawGLScene_2D();
    
    glutSetWindow(window);
    glutPostRedisplay();
    glutTimerFunc(100 , cal_env , 0);
    tcounter++;
}

int main(int argc, char *argv[])
{
    //乱数のシード値を時間によって変更
    srand((unsigned int)time(NULL));
    
    glutInit(&argc, argv);
    glutInitDisplayMode(GLUT_RGBA | GLUT_ALPHA | GLUT_DEPTH);
    glutInitWindowSize(xmax, ymax);
    glutInitWindowPosition(0, 0);
    
    window = glutCreateWindow("Simulation");
    glutDisplayFunc(&DrawGLScene_2D);
    glutReshapeFunc(&ReSizeGLScene_2D);
    glutTimerFunc(100, cal_env, 0);
    
    glutKeyboardFunc(&keyPressed);
    InitGL_2D(xmax, ymax);
    glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
    glutMainLoop();
    return 0;
}

結果

うまくいきました

スクリーンショット 2018-05-30 14.39.40.png

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?