SpikingNeuralNetwork(SNN)について調べていたんだけど、全然文献が見つからない…。
英語で書かれている論文は少しあるんだけど全然読めない(英語超苦手)。
いろんなところの解説を少しずつ読んでいって、なんとか簡単なプログラムを書くことができたので、ここでまとめておきます。
SpikingNeuralNetwork(SNN)とは
スパイキングニューラルネットワークの略で、ニューラルネットワークの一種です。
ニューラルネットワークとは人間の脳神経のニューロンを数理モデル化した単純パーセプトロンを複数組み合わせたものです。
こんな感じ。
ニューラルネットワークは様々な種類が考案されていて、それぞれ得意分野が異なります。
例えば、出力を次の入力として扱うRecurrentNeuralNetwork(RNN)は自然言語処理の分野で大活躍しているし、通常の層と違い、畳み込み層とプーリング層の特殊な層を含む構造であるConvolutionalNeuralNetwork(CNN)は画像処理でよく使われます。
このように、様々な形のニューラルネットワークが存在しますが、その中でも、脳の働きを最も忠実に再現したものがSpikingNeuralNetwork(SNN)です。
ニューロンは電気信号を伝えるのですが、その電気信号を活動電位と呼びます。
ニューロンの活動電位はずっと流れているわけでなく、一瞬だけ電位が急激に上がり、その後すぐに下がります。
こんな感じ。
この、急激に電位が上がることを「スパイク(Spike)」といいます。
このスパイクを再現しているニューラルネットワークをSpikingNeuralNetwork(SNN)と呼ぶのです。
ユニットに入力値から電位を貯めていき、電位が閾値以上になったら出力をします。
このようなユニットを複数組み合わせてSNNを構成します。
コード
勉強した成果の確認のために簡単なSNNのコードを書いてみました。
C言語で書いており、結果描写にはOpenGLを用いています。
#include <stdio.h>
#include <OpenGL/OpenGL.h>
#include <GLUT/GLUT.h>
#include <math.h>
#include <string.h>
#include <time.h>
#include <stdlib.h>
#define inunit 3 //入力次元数
const int xmin=0, xmax=1200,ymin=0, ymax=800;
const int gramin = 50, gramax = 1150;
int cflag = 0;
int window;
//min〜numまでの乱数を生成
int ran(int num,int min){
return rand()%num+min;
}
void makecircle(short cx, short cy, float r, int c) //円の中心x座標、y座標、半径、色の受け取り c=0:緑 c=1:赤 c=2:青 c=3: 黄色 その他:桃
{
int i,n=100; //n=分割数、r=半径
double rate;
float cx2,cy2;
if(c == 0)
glColor3f(0.0f, 1.0f, 0.0f);
else if(c == 1)
glColor3f(1.0f, 0.0f, 0.0f);
else if(c == 2)
glColor3f(0.0f, 0.0f, 1.0f);
else if(c == 3)
glColor3f(1.0f, 0.5f, 0.0f);
else
glColor3f(0.0f, 0.0f, 0.0f);
glBegin(GL_POLYGON); //ポリゴンの頂点記述開始
for (i = 0; i < n; i++) {
rate = (double)i / n;
cx2 = (float)cx + r * cos(2.0 * M_PI * rate); //円の描写
cy2 = (float)cy + r * sin(2.0 * M_PI * rate); // cx、cyは円の中心座標
glVertex3s(cx2, cy2, 0); //ポリゴン頂点の座標
}
glEnd(); //頂点記述の終了
}
void makeLine(short x1, short y1, short x2, short y2,int c){
if(c == 0)
glColor3f(0.0f, 1.0f, 0.0f);
else if(c == 1)
glColor3f(1.0f, 0.0f, 0.0f);
else if(c == 2)
glColor3f(0.0f, 0.0f, 1.0f);
else if(c == 3)
glColor3f(1.0f, 0.5f, 0.0f);
else
glColor3f(0.0f, 0.0f, 0.0f);
glBegin(GL_LINES);
glVertex3s(x1, y1, 0);
glVertex3s(x2, y2, 0);
glEnd();
}
//グラフの描画
void initGraph(){
int in1len = 100;
int in2len = 100;
int in3len = 100;
int outlen = 200;
int span = 20;
//背景の塗りつぶし
glClearColor(1.0, 1.0, 1.0, 1.0);
glClear(GL_COLOR_BUFFER_BIT);
//入力1グラフの描写
makeLine(xmin+50, span, xmin+50, span+in1len, 4);
makeLine(xmin+50, span+in1len, xmax-50, span+in1len, 4);
//入力2グラフの描写
makeLine(xmin+50, span*2+in1len, xmin+50, span*2+in1len+in2len, 4);
makeLine(xmin+50, span*2+in1len+in2len, xmax-50, span*2+in1len+in2len, 4);
//入力3グラフの描写
makeLine(xmin+50, span*3+in1len+in2len, xmin+50, span*3+in1len+in2len+in3len, 4);
makeLine(xmin+50, span*3+in1len+in2len+in3len, xmax-50, span*3+in1len+in2len+in3len, 4);
//出力グラフの描写
makeLine(xmin+50, span*4+in1len+in2len+in3len, xmin+50, span*4+in1len+in2len+in3len+outlen, 4);
makeLine(xmin+50, span*4+in1len+in2len+in3len+outlen, xmax-50, span*4+in1len+in2len+in3len+outlen, 4);
makeLine(xmin+50, span*5+in1len+in2len+in3len+outlen, xmin+50, span*5+in1len+in2len+in3len+outlen+100, 4);
makeLine(xmin+50, span*5+in1len+in2len+in3len+outlen+100, xmax-50, span*5+in1len+in2len+in3len+outlen+100, 4);
}
void spiking(){
int inputs[inunit]; //入力
int output; //出力
static int weight[inunit] = { 4, -2, 3}; //ニューロンの重み
static int t = 0; //グラフがリセットされてからの経過時間
static int solidt = 0; //開始からの時間
static int over = 0; //
int leak = 1; //減衰率
static int v = 0; //電圧
static int prev = 0; //1ステップ前の電圧
int thresh = 8; //閾値
int spike = 4; //スパイクしたときの上昇値
int spikeFlag = 0; //スパイクしたかどうか
static int tNext = 0; //次にアクティブになるまでの時間
int latency = 2; //スパイク後の
//入力値の決定
for(int i=0; i<inunit; i++){
inputs[i] = ran(2,0);
}
//最初にパラメータを表示
if(t == 0){
//重みの表示
printf("The weights are: ");
for(int i=0; i < inunit; i++){
printf("%d ",weight[i]);
}
puts("\n");
//パラメータの表示
printf("The leak potential is : %d\n",leak);
printf("The threshold is : %d\n",thresh);
printf("The spike is : %d\n",spike);
printf("The latency time is : %d\n", latency);
puts("\n");
//グラフの描画
initGraph();
}
//グラフの描画範囲を超えたらリセット
if(((t+1)*7+50) >= (xmax-50)){
t = 0;
initGraph();
}
if(solidt != tNext){
output = 0;
prev = v;
}else{
int sum = 0;
for(int j=0; j<inunit; j++)
sum += inputs[j] * weight[j];
prev = v;
v = v + sum;
v = v - leak;
if(v<0)
v = 0;
if(v>=thresh){
spikeFlag = 1;
output = 1;
tNext = solidt + 1 + latency;
}else{
output = 0;
tNext = solidt + 1;
}
}
//閾値の描画
makeLine(0, -thresh * 10 + 580, xmax, -thresh * 10 + 580,0);
//入力の描画
if(inputs[0] == 1){
makeLine((t+1)*7-2+50, 120,(t+1)*7-2+50,40,4);
}
if(inputs[1] == 1){
makeLine((t+1)*7-2+50, 240,(t+1)*7-2+50,160,4);
}
if(inputs[2] == 1){
makeLine((t+1)*7-2+50, 360,(t+1)*7-2+50,280,4);
}
//出力の描写
if(output == 1){
makeLine((t+1)*7-2+50, 700, (t+1)*7-2+50,620,4);
}
//SNNグラフの描画
if(spikeFlag == 1){
makeLine(t*7+50, -prev*10+580, (t+1)*7-2+50, -v*10+580,2);
makecircle((t+1)*7-2+50, -v*10+580, 2, 1);
prev = v;
v += spike;
makeLine((t+1)*7-2+50, -prev*10+580, (t+1)*7+50, -v*10+580,1);
makecircle((t+1)*7+50, -v*10+580, 2, 1);
prev = v;
v = 0;
makeLine((t+1)*7+50, -prev*10+580, (t+1)*7+50, -v*10+580,2);
makecircle((t+1)*7+50, -v*10+580,2,1);
spikeFlag = 0;
}else{
makeLine(t*7+50, -prev*10+580, (t+1)*7+50, -v*10+580,2);
makecircle((t+1)*7+50, -v*10+580,2,1);
}
solidt++;
t++;
}
void DrawGLScene_2D()
{
glLoadIdentity();
spiking();
glFlush();
}
void ReSizeGLScene_2D(int Width, int Height)
{
glViewport(0,0,Width,Height);
glMatrixMode(GL_PROJECTION);
glLoadIdentity();
glOrtho (0, Width, Height, 0, -1.0f, 1.0f);
glMatrixMode(GL_MODELVIEW);
}
void InitGL_2D(int Width, int Height)
{
glClearColor(1.0f, 1.0f, 1.0f, 1.0f);
glClearDepth(1.0);
glDepthFunc(GL_LESS);
glDisable(GL_DEPTH_TEST);
glEnable(GL_BLEND);
glBlendFunc (GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA);
glShadeModel(GL_SMOOTH);
ReSizeGLScene_2D(Width, Height);
}
void keyPressed(unsigned char key, int x1, int y1)
{
if (key == 'c'){
cflag = 1;
}
}
#pragma mark Timer
int tcounter=0;
GLfloat top = -0.9;
void cal_env(int vl)
{
static GLboolean isUp = GL_TRUE;
if (top > 0.9F) isUp = GL_FALSE;
else if (top <= -0.9F) isUp = GL_TRUE;
top += (isUp == GL_TRUE ? 0.01 : -0.01);
DrawGLScene_2D();
glutSetWindow(window);
glutPostRedisplay();
glutTimerFunc(100 , cal_env , 0);
tcounter++;
}
int main(int argc, char *argv[])
{
//乱数のシード値を時間によって変更
srand((unsigned int)time(NULL));
glutInit(&argc, argv);
glutInitDisplayMode(GLUT_RGBA | GLUT_ALPHA | GLUT_DEPTH);
glutInitWindowSize(xmax, ymax);
glutInitWindowPosition(0, 0);
window = glutCreateWindow("Simulation");
glutDisplayFunc(&DrawGLScene_2D);
glutReshapeFunc(&ReSizeGLScene_2D);
glutTimerFunc(100, cal_env, 0);
glutKeyboardFunc(&keyPressed);
InitGL_2D(xmax, ymax);
glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
glutMainLoop();
return 0;
}
結果
うまくいきました