0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

最急降下法(勾配降下法)を可視化してみた

Last updated at Posted at 2024-08-28

私は機械学習とかディープラーニングなどについてはよく分からないのですが、せっかくグラフ表示のプログラムを作成したので、できれば他にもいろいろ表示させてみたいということで、最急降下法(勾配降下法)を可視化してみました。
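最急降下法(勾配降下法)は関数の傾き(勾配)に注目して探索を行う方法のようです。今回のプログラムでは、各点 $(x, y)$ を $(x, y) \leftarrow (x, y) - \eta\,\nabla f(x, y)$（学習率 $\eta = 0.02$、勾配は中心差分による数値微分で計算）で繰り返し更新しています。私がいい加減に説明するのもどうかと思うので、詳しく知りたい方は参考にさせていただいたサイトを読むなり、ネットで検索していただくのがいいかと思います。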

4本のプログラムは関数部分（`function` と `set_function_conditions`）以外は全く同じなので、2本目からはその部分のみを掲載しました。

勾配降下法について参考にさせていただいたサイト

その他参考にさせていただいたサイト

図1〜図4に、各関数に対する探索の様子をグラフで示しました。

図1 $\quad5(x-1)^2+2(y-2)^2$
GradientDescent0003.gif

図2 $\quad x^2+y^2+20\sin^2 x$
GradientDescent0004.gif

図3 $\quad x^3+y^3-9xy+27$
GradientDescent0005.gif

図4 $\quad0.5x^4+0.1y^4-0.01x^3y^3-x-0.3y$
GradientDescent0006.gif

GradientDescent1.cpp
#define _USE_MATH_DEFINES
#include <math.h>

#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iostream>
#include <string>
using namespace std;

// Random real number in the closed interval [0, 1]: rand() returns 0..RAND_MAX,
// so both endpoints can occur.
#define RAND_01 ((double)rand() / RAND_MAX)

// Formula text shown on the graph; written as the first line of graph_data.txt
// (uses "**" for powers). This default is overwritten by set_function_conditions().
string str_formula = "x**2 + y**2"; // formula displayed on the graph

const int ITERATION = 100; // number of recorded steps per particle (row 0 = random start)

const int SWARM_SIZE = 20; // number of particles, each descending from its own random start

// Axis ranges of the graph. The X and Y ranges also bound the random start
// positions (initialize()) and clamp every descent step (gradient_descent());
// the Z range is only used for output. All of these defaults are overwritten
// by set_function_conditions().
double X_MAX = 10; // upper limit of the graph's X-axis range
double X_MIN = -10; // lower limit of the graph's X-axis range

double Y_MAX = 10; // upper limit of the graph's Y-axis range
double Y_MIN = -10; // lower limit of the graph's Y-axis range

double Z_MAX = 200; // upper limit of the graph's Z-axis range
double Z_MIN = 0; // lower limit of the graph's Z-axis range

// Descent history: (x[i][j], y[i][j]) is the position of particle j at step i
// (row 0 = random start) and z[i][j] = function(x[i][j], y[i][j]).
double x[ITERATION][SWARM_SIZE];
double y[ITERATION][SWARM_SIZE];
double z[ITERATION][SWARM_SIZE];

// Objective for this run: f(x, y) = 5(x - 1)^2 + 2(y - 2)^2, an elliptic
// paraboloid whose single minimum f = 0 lies at (1, 2).
double function(double x, double y)
{
	const double dx = x - 1;
	const double dy = y - 2;
	return 5 * dx * dx + 2 * dy * dy;
}

// Configures the run for f(x, y) = 5(x - 1)^2 + 2(y - 2)^2: the formula text
// written to the output and the X/Y/Z axis ranges of the graph.
void set_function_conditions()
{
	str_formula = "5 * (x - 1)**2 + 2 * (y - 2)**2";

	X_MIN = -6;
	X_MAX = 4;

	Y_MIN = -6;
	Y_MAX = 4;

	Z_MIN = 0;
	Z_MAX = 400;
}

// Writes the header section to `out`, one item per line: the formula text, the
// run sizes, the X/Y/Z axis ranges, and finally the "DATA " marker line that
// precedes the samples.
static void write_data_header(std::ostream& out)
{
	out << str_formula << '\n';
	out << "ITERATION " << ITERATION << '\n';
	out << "SWARM_SIZE " << SWARM_SIZE << '\n';
	out << "X_MAX " << X_MAX << '\n';
	out << "X_MIN " << X_MIN << '\n';
	out << "Y_MAX " << Y_MAX << '\n';
	out << "Y_MIN " << Y_MIN << '\n';
	out << "Z_MAX " << Z_MAX << '\n';
	out << "Z_MIN " << Z_MIN << '\n';
	out << "DATA " << '\n';
}

// Writes the recorded trajectories to `out`: for each step, one "x y z" line
// per particle followed by a blank line. When `label_iterations` is true, each
// step's group is preceded by a line holding the step index.
static void write_data_samples(std::ostream& out, bool label_iterations)
{
	for (int i = 0; i < ITERATION; i++) {
		if (label_iterations)
			out << i << '\n';
		for (int j = 0; j < SWARM_SIZE; j++)
			out << x[i][j] << " " << y[i][j] << " " << z[i][j] << '\n';
		out << '\n';
	}
}

// Saves the run to graph_data.txt and echoes it to stdout.
// The file gets the header and the x/y/z samples. The console copy is the same
// except that each step's group is prefixed with its step index.
// Failure to open or write the file is reported on stderr. The console echo is
// produced either way.
void write_data()
{
	std::ofstream fout("graph_data.txt");
	const bool opened = fout.is_open();
	if (!opened)
		std::cerr << "write_data: cannot open graph_data.txt for writing" << std::endl;

	write_data_header(fout);
	write_data_header(std::cout);

	write_data_samples(fout, false);
	write_data_samples(std::cout, true);

	std::cout.flush();

	if (opened) {
		fout.close();
		// close() flushes; a failure here means the data did not reach the file.
		if (!fout)
			std::cerr << "write_data: error while writing graph_data.txt" << std::endl;
	}
}

// Places each of the SWARM_SIZE particles at a random point inside the box
// [X_MIN, X_MAX] x [Y_MIN, Y_MAX] and stores its function value, filling
// row 0 of the x/y/z history arrays.
void initialize()
{
	for (int p = 0; p < SWARM_SIZE; ++p) {
		// The x coordinate is drawn before the y coordinate (two rand() calls per particle).
		const double start_x = X_MIN + (X_MAX - X_MIN) * RAND_01;
		const double start_y = Y_MIN + (Y_MAX - Y_MIN) * RAND_01;

		x[0][p] = start_x;
		y[0][p] = start_y;
		z[0][p] = function(start_x, start_y);
	}
}

// Learning rate: the factor applied to the gradient in each descent step
// (new position = old position - lr * gradient).
double lr = 0.02;
// Output of numerical_gradient(): the partial derivatives d/dx and d/dy at the
// point most recently passed to it; read by gradient_descent().
double grad_x, grad_y;

// Estimates the gradient of function() at (x1, y1) by central differences,
// (f(p + h) - f(p - h)) / 2h along each axis, with h = 0.0001.
// Nothing is returned: the partial derivatives are stored in the globals
// grad_x (d/dx) and grad_y (d/dy).
void numerical_gradient(double x1, double y1) {
	const double step = 0.0001;
	const double span = 2 * step;

	const double fx_plus = function(x1 + step, y1);
	const double fx_minus = function(x1 - step, y1);
	grad_x = (fx_plus - fx_minus) / span;

	const double fy_plus = function(x1, y1 + step);
	const double fy_minus = function(x1, y1 - step);
	grad_y = (fy_plus - fy_minus) / span;
}

// Gradient descent for every particle. Starting from row 0 (filled by
// initialize()), each step:
//   - moves a particle against the numerical gradient scaled by lr,
//   - clamps the new position into [X_MIN, X_MAX] x [Y_MIN, Y_MAX],
//   - records the function value there.
// Row t of x/y/z holds step t.
void gradient_descent()
{
	// Clip v into [lo, hi], testing the upper bound first.
	auto clip = [](double v, double lo, double hi) {
		if (v > hi)
			v = hi;
		if (v < lo)
			v = lo;
		return v;
	};

	for (int step = 1; step < ITERATION; ++step) {
		const int prev = step - 1;
		for (int p = 0; p < SWARM_SIZE; ++p) {
			// Fills grad_x / grad_y for the particle's previous position.
			numerical_gradient(x[prev][p], y[prev][p]);

			x[step][p] = clip(x[prev][p] - lr * grad_x, X_MIN, X_MAX);
			y[step][p] = clip(y[prev][p] - lr * grad_y, Y_MIN, Y_MAX);
			z[step][p] = function(x[step][p], y[step][p]);
		}
	}
}

// Entry point. Steps, in order:
//   1. seed rand() from the current time, so each run starts from different random points;
//   2. configure the target function and graph ranges;
//   3. scatter the swarm;
//   4. run gradient descent;
//   5. write graph_data.txt and echo it to stdout.
int main()
{
	std::srand(static_cast<unsigned>(std::time(nullptr)));

	set_function_conditions();

	initialize();
	gradient_descent();

	write_data();
	return 0;
}

GradientDescent2.cpp
// Objective for this run: f(x, y) = x^2 + y^2 + 20 sin^2(x). The sin^2 term
// adds ripples (local minima) along x; the global minimum f = 0 is at (0, 0).
double function(double x, double y)
{
	const double s = sin(x);
	return x * x + y * y + 20.0 * s * s;
}

// Configures the run for f(x, y) = x^2 + y^2 + 20 sin^2(x): the formula text
// written to the output and the X/Y/Z axis ranges of the graph.
void set_function_conditions()
{
	str_formula = "x**2 + y**2 + 20.0 * sin(x)**2";

	X_MIN = -10;
	X_MAX = 10;

	Y_MIN = -10;
	Y_MAX = 10;

	Z_MIN = 0;
	Z_MAX = 200;
}

GradientDescent3.cpp
// Objective for this run: f(x, y) = x^3 + y^3 - 9xy + 27. It has a saddle at
// (0, 0) and a local minimum f = 0 at (3, 3), and is unbounded below.
double function(double x, double y)
{
	const double cubes = x * x * x + y * y * y;
	return cubes - 9.0 * x * y + 27.0;
}

// Configures the run for f(x, y) = x^3 + y^3 - 9xy + 27: the formula text
// written to the output and the X/Y/Z axis ranges of the graph.
void set_function_conditions()
{
	str_formula = "x**3 + y**3 - 9.0 * x * y + 27.0";

	X_MIN = -10;
	X_MAX = 10;

	Y_MIN = -10;
	Y_MAX = 10;

	Z_MIN = -3000;
	Z_MAX = 1500;
}

GradientDescent4.cpp
// Objective for this run: f(x, y) = 0.5x^4 + 0.1y^4 - 0.01x^3y^3 - x - 0.3y.
double function(double x, double y)
{
	const double quartics = 0.5 * x * x * x * x + 0.1 * y * y * y * y;
	return quartics - 0.01 * x * x * x * y * y * y - x - 0.3 * y;
}

// Configures the run for f(x, y) = 0.5x^4 + 0.1y^4 - 0.01x^3y^3 - x - 0.3y:
// the formula text written to the output and the X/Y/Z axis ranges of the graph.
void set_function_conditions()
{
	str_formula = "0.5 * x**4 + 0.1 * y**4 - 0.01 * x**3 * y**3 - x - 0.3 * y";

	X_MIN = -6;
	X_MAX = 6;

	Y_MIN = -6;
	Y_MAX = 6;

	Z_MIN = 0;
	Z_MAX = 1000;
}

0
0
4

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?