概要
4x4リバーシのPythonプログラムを作成しましたが、高速化するため、C++で記述し、pybind11を使ってPythonから呼び出せるようにします。ロジックはPythonプログラムと同一です。
環境
ホストOS Ubuntu18.04 にdockerをインストール、docker内でPythonを動かしています。
- Ubuntu18.04
- nvidia-driver-410
- nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04
- Python 3.7.3
- tensorflow-gpu 1.13.1
- g++ (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
プログラム
配置
deep-learning-reversi/
├ dlreversi4/
│ ├ cpp/
│ │ ├ Game.hpp
│ │ ├ State.hpp
│ │ ├ State.cpp
│ │ └ reversi4.cpp
│ └ game.py
├ performance_check.py
└ setup.py
State
ゲームの状態を表すクラスです。game.pyのStateクラスに当たります。
### ヘッダーファイル
State.hpp
#ifndef STATE_HPP_INCLUDED
#define STATE_HPP_INCLUDED
#include <stdint.h>
#include <array>
#include <cstdint>
#include <string>
#include <vector>
// State of a 4x4 reversi game, always stored from the point of view of the
// player to move: `pieces` / `piecesBit` are that player's discs and
// `enemyPieces` / `enemyPiecesBit` the opponent's. Cell i (row-major, 0..15)
// corresponds to bit (0x8000 >> i) of the bit boards.
class State
{
public:
    static const unsigned BOARD_SIZE = 4;
    static const unsigned BOARD_AREA = BOARD_SIZE * BOARD_SIZE;
    // Action value meaning "pass" (one past the last cell index).
    static const unsigned PASS = BOARD_SIZE * BOARD_SIZE;
    // Initial position (the four centre cells occupied), depth 0.
    State();
    // Position built from per-cell arrays of both sides (any value > 0 is a disc).
    State(const std::array<unsigned, BOARD_AREA> pieces, const std::array<unsigned, BOARD_AREA> enemyPieces, const unsigned depth);
    // No user-declared destructor (Rule of Zero): an empty `~State() {}`
    // suppresses the implicit move constructor/assignment, so any State that
    // was not elided got copied (including its legalActions vector).

    // True when the game is over and the player to move has fewer discs.
    bool isLose() const;
    // True when the game is over with equal disc counts.
    bool isDraw() const;
    // True when the board is full or the game ended by consecutive passes.
    bool isDone() const;
    // True at even depths, i.e. when the first player is to move.
    bool isFirstPlayer() const;
    // Successor state after `action` (a cell index, or PASS); the side to
    // move is swapped in the returned state.
    State next(const unsigned& action);
    // Text rendering of the board, one row per line.
    std::string toString() const;
    // Disc counts of the player to move and of the opponent.
    unsigned piecesCount() const;
    unsigned enemyPiecesCount() const;
    // Disc placement as per-cell arrays.
    std::array<unsigned, BOARD_AREA> pieces;
    std::array<unsigned, BOARD_AREA> enemyPieces;
    // Disc placement as bit boards.
    std::uint16_t piecesBit;
    std::uint16_t enemyPiecesBit;
    // Number of moves played so far.
    unsigned depth;
    // Legal cell indices for the player to move, or {PASS} when there are none.
    std::vector<unsigned> legalActions;
private:
    // Centre 2x2 cells: enemy mask for the diagonal scans.
    static const std::uint16_t MASK_EDGE = 0x0660;
    // Columns 1-2: enemy mask for horizontal scans (prevents row wrap-around).
    static const std::uint16_t MASK_VERTICAL = 0x6666;
    // Rows 1-2: enemy mask for vertical scans.
    static const std::uint16_t MASK_HORIZONTAL = 0x0FF0;
    // Bit distances between neighbouring cells along each axis.
    static const unsigned DIR_LEFT_OBRIQUE = 5;
    static const unsigned DIR_RIGHT_OBLIQUE = 3;
    static const unsigned DIR_HORIZONTAL = 1;
    static const unsigned DIR_VERTICAL = 4;
    // Number of set bits (discs) in a bit board.
    unsigned countBit(const std::uint16_t& bit) const;
    // Legal moves as a bit board.
    std::uint16_t generateLegal() const;
    std::uint16_t generateSomeLegal(const std::uint16_t& pieces_bit, const std::uint16_t& enemy_pieces_bit, const unsigned& direction) const;
    // Whether the given single-cell bit is a legal move.
    bool isLegalAction(const std::uint16_t& actionBit) const;
    // Place a disc and flip the captured enemy discs.
    void move(const std::uint16_t& actionBit);
    std::uint16_t generateSomeFlipped(
        const std::uint16_t& piecesBit, const std::uint16_t& enemyPiecesBit, const std::uint16_t& actionBit, const unsigned& direction) const;
    // Bit board -> per-cell 0/1 array.
    std::array<unsigned, BOARD_AREA> bitToArray(const std::uint16_t& bit) const;
    // Per-cell array -> bit board.
    std::uint16_t arrayToBit(const std::array<unsigned, BOARD_AREA>& array) const;
    // Legal-move bit board -> list of cell indices (name keeps its historical spelling).
    std::vector<unsigned> legalActionToAraay(const std::uint16_t& bit) const;
    // Set when the game ended because of consecutive passes.
    bool passEnd;
    // Legal moves as a bit board.
    std::uint16_t legalActionsBit;
};
#endif
ソースファイル
State.cpp
#include "State.hpp"
#include <iostream>
#include <sstream>
// ゲーム状態
State::State()
{
// 連続パスによる終了
passEnd = false;
// 石の配置
this->pieces = { 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0 };
this->enemyPieces = { 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0 };
this->piecesBit = arrayToBit(this->pieces);
this->enemyPiecesBit = arrayToBit(this->enemyPieces);
// 手数
this->depth = 0;
// 合法手の設定
this->legalActionsBit = generateLegal();
this->legalActions = legalActionToAraay(this->legalActionsBit);
}
// ゲーム状態
State::State(const std::array<unsigned, BOARD_AREA> pieces, const std::array<unsigned, BOARD_AREA> enemyPieces, const unsigned depth)
{
// 連続パスによる終了
passEnd = false;
// 石の配置
this->pieces = pieces;
this->enemyPieces = enemyPieces;
this->piecesBit = arrayToBit(this->pieces);
this->enemyPiecesBit = arrayToBit(this->enemyPieces);
// 手数
this->depth = depth;
// 合法手の設定
this->legalActionsBit = generateLegal();
this->legalActions = legalActionToAraay(this->legalActionsBit);
}
// 負けかどうか
bool State::isLose() const
{
return isDone() && (countBit(this->piecesBit) < countBit(this->enemyPiecesBit));
}
// 引き分けかどうか
bool State::isDraw() const
{
return isDone() && (countBit(this->piecesBit) == countBit(this->enemyPiecesBit));
}
// ゲーム終了かどうか
bool State::isDone() const
{
return (countBit(this->piecesBit | this->enemyPiecesBit) == BOARD_AREA) || passEnd;
}
// 先手かどうか
bool State::isFirstPlayer() const
{
return this->depth % 2 == 0;
}
// Successor state after playing `action` (a cell index 0..BOARD_AREA-1, or PASS).
// The returned state is seen from the opponent's side: depth is incremented
// and the two sides' discs are swapped. A cell action is applied only when
// isLegalAction() accepts it; otherwise the board is left unchanged.
// If this action is a pass and the opponent has no legal move either,
// passEnd marks the returned state as finished.
State State::next(const unsigned& action)
{
    // Any action >= BOARD_AREA (PASS included) means "no cell". Guarding the
    // shift also avoids undefined behaviour: `0x8000 >> action` is UB once
    // action reaches the bit width of int, and action may come straight
    // from Python.
    std::uint16_t actionBit = 0;
    if (action < BOARD_AREA)
        actionBit = static_cast<std::uint16_t>(0x8000u >> action);
    State state(this->pieces, this->enemyPieces, this->depth + 1);
    if (actionBit != 0 && isLegalAction(actionBit))
        state.move(actionBit);
    // Hand the turn to the opponent by swapping the two sides.
    const std::array<unsigned, BOARD_AREA> w = state.pieces;
    const std::uint16_t wb = state.piecesBit;
    state.pieces = state.enemyPieces;
    state.piecesBit = state.enemyPiecesBit;
    state.enemyPieces = w;
    state.enemyPiecesBit = wb;
    // Legal moves for the new player to move.
    state.legalActionsBit = state.generateLegal();
    state.legalActions = legalActionToAraay(state.legalActionsBit);
    // Two passes in a row end the game.
    if (actionBit == 0 && state.legalActionsBit == 0)
        state.passEnd = true;
    return state;
}
// 石の数
unsigned State::piecesCount() const
{
return countBit(this->piecesBit);
}
unsigned State::enemyPiecesCount() const
{
return countBit(this->enemyPiecesBit);
}
// Population count of a 16-bit board (number of discs).
// SWAR reduction: add neighbouring 1-, 2-, 4- and finally 8-bit fields.
unsigned State::countBit(const std::uint16_t& bit) const
{
    unsigned v = bit;
    v = (v & 0x5555u) + ((v >> 1) & 0x5555u);
    v = (v & 0x3333u) + ((v >> 2) & 0x3333u);
    v = (v & 0x0F0Fu) + ((v >> 4) & 0x0F0Fu);
    // v never exceeds 16 bits here, so the high byte needs no mask.
    return (v & 0x00FFu) + (v >> 8);
}
// 合法手の取得
std::uint16_t State::generateLegal() const
{
return ~(this->piecesBit | this->enemyPiecesBit)
& (generateSomeLegal(this->piecesBit, this->enemyPiecesBit & MASK_EDGE, DIR_LEFT_OBRIQUE)
| generateSomeLegal(this->piecesBit, this->enemyPiecesBit & MASK_EDGE, DIR_RIGHT_OBLIQUE)
| generateSomeLegal(this->piecesBit, this->enemyPiecesBit & MASK_VERTICAL, DIR_HORIZONTAL)
| generateSomeLegal(this->piecesBit, this->enemyPiecesBit & MASK_HORIZONTAL, DIR_VERTICAL));
}
// Candidate moves along one axis: cells just beyond a contiguous run of enemy
// discs that starts next to an own disc. Both directions of the axis are
// scanned at once (shift by +direction and by -direction).
// enemy_pieces_bit is expected to be pre-masked by the caller (generateLegal)
// so that the shifts cannot wrap across a board edge.
// The result may include occupied cells; generateLegal() keeps only the empty ones.
inline std::uint16_t State::generateSomeLegal(const std::uint16_t& pieces_bit, const std::uint16_t& enemy_pieces_bit, const unsigned& direction) const
{
// Enemy discs directly adjacent to an own disc.
std::uint16_t flipped = ((pieces_bit << direction) | (pieces_bit >> direction)) & enemy_pieces_bit;
// Extend the runs of enemy discs. With the inner-cell masks used by
// generateLegal the runs on a 4x4 board are at most 2 cells long, so 6 extra
// steps are more than enough.
for (unsigned i = 0; i < 6; i++)
flipped |= ((flipped << direction) | (flipped >> direction)) & enemy_pieces_bit;
// Cells one step beyond the runs (truncated to 16 bits on return).
return flipped << direction | flipped >> direction;
}
// True when the single-cell bit `actionBit` is one of this state's legal moves.
// The specific bit must be tested: checking only that actionBit is non-zero
// and that some legal move exists would accept any cell whenever the player
// has at least one legal move.
inline bool State::isLegalAction(const std::uint16_t& actionBit) const
{
    return (actionBit & this->legalActionsBit) != 0;
}
// 石を置く。
void State::move(const std::uint16_t& actionBit)
{
std::uint16_t flipped = generateSomeFlipped(this->piecesBit, this->enemyPiecesBit & MASK_EDGE, actionBit, DIR_LEFT_OBRIQUE);
flipped |= generateSomeFlipped(this->piecesBit, this->enemyPiecesBit & MASK_EDGE, actionBit, DIR_RIGHT_OBLIQUE);
flipped |= generateSomeFlipped(this->piecesBit, this->enemyPiecesBit & MASK_VERTICAL, actionBit, DIR_HORIZONTAL);
flipped |= generateSomeFlipped(this->piecesBit, this->enemyPiecesBit & MASK_HORIZONTAL, actionBit, DIR_VERTICAL);
this->piecesBit = this->piecesBit | actionBit | flipped;
this->enemyPiecesBit = this->enemyPiecesBit ^ flipped;
this->pieces = bitToArray(this->piecesBit);
this->enemyPieces = bitToArray(this->enemyPiecesBit);
}
// Enemy discs flipped along one axis when a disc is placed at actionBit.
// Both directions are checked: "left" = shift by +direction (towards higher
// bits / lower cell indices), "right" = shift by -direction.
// enemyPiecesBit must be pre-masked by the caller (move) so that the shifts
// cannot wrap across a board edge. Only runs of one or two enemy discs are
// detected; with the inner-cell masks used by move(), two is the longest run
// possible on a 4x4 board.
inline std::uint16_t State::generateSomeFlipped(
const std::uint16_t& piecesBit, const std::uint16_t& enemyPiecesBit, const std::uint16_t& actionBit, const unsigned& direction) const
{
// Enemy discs directly next to the placed disc, on each side.
const uint16_t leftEnemy = (actionBit << direction) & enemyPiecesBit;
const uint16_t rightEnemy = (actionBit >> direction) & enemyPiecesBit;
// leftSelf: enemy discs with an own disc `direction` bits below them.
// rightSelf: enemy discs with an own disc `direction` bits above them.
const uint16_t leftSelf = (piecesBit << direction) & enemyPiecesBit;
const uint16_t rightSelf = (piecesBit >> direction) & enemyPiecesBit;
// Upward runs: 1 disc (enemy, own) or 2 discs (enemy, enemy, own) - the
// first two terms yield the first and second disc; the last two terms do the
// same for downward runs.
return ((leftEnemy & (rightSelf | (rightSelf >> direction))) |
((leftEnemy << direction) & rightSelf) | (rightEnemy & (leftSelf | (leftSelf << direction))) |
((rightEnemy >> direction) & leftSelf));
}
// Expand a bit board into a per-cell 0/1 array; cell i is bit (0x8000 >> i).
inline std::array<unsigned, State::BOARD_AREA> State::bitToArray(const std::uint16_t& bit) const
{
    std::array<unsigned, BOARD_AREA> cells;
    for (unsigned idx = 0; idx < BOARD_AREA; ++idx)
    {
        // Cell 0 lives in the most significant bit.
        const unsigned shift = BOARD_AREA - 1 - idx;
        cells[idx] = (bit >> shift) & 1u;
    }
    return cells;
}
// 石の配列表現をビットに変換
inline std::uint16_t State::arrayToBit(const std::array<unsigned, BOARD_AREA>& array) const
{
std::uint16_t bit = 0;
std::uint16_t mask = 0x8000;
for (unsigned i = 0; i < array.size(); ++i)
{
if (array[i] > 0)
bit |= mask;
mask = mask >> 1;
}
return bit;
}
// Convert a legal-move bit board into ascending cell indices; an empty board
// (no legal move) becomes the single action PASS.
inline std::vector<unsigned> State::legalActionToAraay(const std::uint16_t& bit) const
{
    if (bit == 0)
    {
        // Copy PASS into a local first so it is not bound to a reference.
        const unsigned passAction = PASS;
        return std::vector<unsigned>(1, passAction);
    }
    std::vector<unsigned> actions;
    actions.reserve(countBit(bit));
    for (unsigned cell = 0; cell < BOARD_AREA; ++cell)
    {
        if ((bit & (0x8000u >> cell)) != 0)
            actions.push_back(cell);
    }
    return actions;
}
// 文字列化
std::string State::toString() const
{
std::uint16_t mask = 0x8000;
std::array<std::string, 2> ox;
if (isFirstPlayer())
ox = { "o", "x" };
else
ox = { "x", "o" };
std::stringstream ss;
for (unsigned i = 0; i < BOARD_AREA; i++)
{
if ((this->piecesBit & mask) != 0)
ss << ox[0];
else if ((this->enemyPiecesBit & mask) != 0)
ss << ox[1];
else
ss << '-';
if (i % 4 == 3)
ss << std::endl;
mask = mask >> 1;
}
return ss.str();
}
Game
Game.hpp
#ifndef GAME_HPP_INCLUDED
#define GAME_HPP_INCLUDED
#include "State.hpp"
#include <random>
// Choose one of state.legalActions uniformly at random; returns State::PASS
// if the list is empty.
// Declared `inline` because it is defined in a header: otherwise including
// Game.hpp from more than one translation unit fails to link with
// multiple-definition errors.
inline unsigned randomAction(const State& state)
{
    const unsigned size = state.legalActions.size();
    if (size == 0)
        return State::PASS;
    if (size == 1)
        return state.legalActions[0];
    // One engine per thread, seeded once from the non-deterministic
    // random_device, instead of building a random_device on every call.
    static thread_local std::mt19937 mt(std::random_device{}());
    // Integer distribution over the closed range [0, size - 1]. A
    // uniform_real_distribution(0, size - 1) yields values in [0, size - 1),
    // so truncating it to an index could never select the last legal action.
    std::uniform_int_distribution<unsigned> pick(0, size - 1);
    return state.legalActions[pick(mt)];
}
#endif
reversi4.cpp
reversi4.cpp
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <cstdint>
#include <string>
#include "State.hpp"
#include "Game.hpp"
namespace py = pybind11;
// Python module `dlreversi4.cpp`: exposes State and random_action with the
// same snake_case names as the pure-Python implementation.
PYBIND11_MODULE(cpp, m)
{
    m.def("random_action", &randomAction, py::arg("state"));
    py::class_<State> state(m, "State");
    state
        .def(py::init<>())
        // Board arrays are sized by State::BOARD_AREA instead of a literal 16;
        // keyword names allow State(pieces=..., enemy_pieces=..., depth=...).
        .def(py::init<std::array<unsigned, State::BOARD_AREA>, std::array<unsigned, State::BOARD_AREA>, unsigned>(),
             py::arg("pieces"), py::arg("enemy_pieces"), py::arg("depth"))
        .def_readonly("pieces", &State::pieces)
        .def_readonly("enemy_pieces", &State::enemyPieces)
        .def_readonly("pieces_bit", &State::piecesBit)
        .def_readonly("enemy_pieces_bit", &State::enemyPiecesBit)
        .def_readonly("depth", &State::depth)
        .def_readonly("legal_actions", &State::legalActions)
        .def("is_lose", &State::isLose)
        .def("is_draw", &State::isDraw)
        .def("is_done", &State::isDone)
        .def("is_first_player", &State::isFirstPlayer)
        .def("next", &State::next, py::arg("action"))
        .def("pieces_count", &State::piecesCount)
        .def("enemy_pieces_count", &State::enemyPiecesCount)
        .def("__str__", &State::toString);
    m.doc() = "cpp reversi4 plugin";
#ifdef VERSION_INFO
    m.attr("__version__") = VERSION_INFO;
#else
    m.attr("__version__") = "dev";
#endif
}
setup.py
setup.py
from setuptools import setup
from setuptools import find_packages
import os, sys
import pybind11
from setuptools import Extension
from distutils import sysconfig
# Remove the "-Wstrict-prototypes" compiler option, which isn't valid for C++.
import distutils.sysconfig
cfg_vars = distutils.sysconfig.get_config_vars()
for key, value in cfg_vars.items():
    if type(value) == str:
        cfg_vars[key] = value.replace("-Wstrict-prototypes", "")

# Compiler settings for the C++ extension.
# pybind11.get_include() returns the directory that contains
# `pybind11/pybind11.h`, so it is listed directly; the dirname() entries of
# the previous configuration are kept so existing build environments that
# relied on them keep working.
kwds = dict(
    extra_compile_args=['-std=c++11'],
    include_dirs=[
        pybind11.get_include(True),
        pybind11.get_include(False),
        os.path.dirname(pybind11.get_include(True)),
        os.path.dirname(pybind11.get_include(False)),
    ],
)
if sys.platform == 'darwin':
    kwds["extra_compile_args"].append('-mmacosx-version-min=10.7')
    kwds["extra_compile_args"].append('-stdlib=libc++')

# The compiled module is importable as `dlreversi4.cpp`.
ext_modules = [
    Extension(
        'dlreversi4.cpp',
        sources=[
            os.path.join('dlreversi4', 'cpp', 'State.cpp'),
            os.path.join('dlreversi4', 'cpp', 'reversi4.cpp')
        ],
        **kwds
    )
]

setup(
    name='deep-learning-reversi',
    version='0.0.1',
    author='',
    author_email='',
    # Only real Python packages belong here. `dlreversi4.cpp` is the extension
    # module declared in ext_modules; listing it as a package as well makes
    # setuptools look for a `dlreversi4/cpp/__init__.py`, which the source
    # tree (C++ sources only) does not contain.
    packages=['dlreversi4'],
    install_requires=['pyprind', 'psutil'],
    license='MIT',
    zip_safe=False,
    description='Deep Learning Reversi',
    ext_modules=ext_modules
)
game.py
game.py を参照してください。
performance_check.py
performance_check.py
from dlreversi4 import game
import dlreversi4.cpp as cpp
import pyprind
import argparse
def play(state, algorithm, verbose=False):
    """Play one game from ``state`` until it is over.

    Args:
        state: a game state (``dlreversi4.cpp.State`` or ``dlreversi4.game.State``)
            providing ``is_done()``, ``next(action)``, ``depth`` and ``legal_actions``.
        algorithm: callable mapping a state to the action to play.
        verbose: print every action and the resulting board.

    Returns:
        The terminal state.
    """
    # The verbosity flag is a parameter rather than the module-level ``args``,
    # which only exists when this file runs as a script.
    while not state.is_done():
        action = algorithm(state)
        state = state.next(action)
        if verbose:
            print('action: {}, depth: {}, {}'.format(action, state.depth, state.legal_actions))
            print(state)
            print()
    return state


if __name__ == '__main__':
    # Command-line arguments.
    parser = argparse.ArgumentParser(description='Benchmark self-play of the C++ and Python reversi implementations.')
    parser.add_argument('--trials', '-t', type=int, default=10000, help='number of trials')
    parser.add_argument('--algorithm', '-a', type=str, default='random', help='algorithm')
    parser.add_argument('--verbose', '-v', action='store_true')
    args = parser.parse_args()

    # Trials: run the same self-play loop with each implementation.
    # getattr instead of eval(): the algorithm name comes from the command line.
    for title, module in (('C++', cpp), ('Python', game)):
        bar = pyprind.ProgBar(args.trials, monitor=True, title=title)
        algorithm = getattr(module, '{}_action'.format(args.algorithm))
        for _ in range(args.trials):
            play(module.State(), algorithm, args.verbose)
            bar.update()
        print(bar)
インストール
上記プログラムをpipでPython環境にインストールします。setup.pyのあるディレクトリーに移動し、下記コマンドを実行します。
$ pip install -e . --no-cache-dir
Obtaining file:///ai
Requirement already satisfied: pyprind in /root/.pyenv/versions/3.7.3/lib/python3.7/site-packages (from deep-learning-reversi==0.0.1) (2.11.2)
Requirement already satisfied: psutil in /root/.pyenv/versions/3.7.3/lib/python3.7/site-packages (from deep-learning-reversi==0.0.1) (5.6.3)
Installing collected packages: deep-learning-reversi
Found existing installation: deep-learning-reversi 0.0.1
Uninstalling deep-learning-reversi-0.0.1:
Successfully uninstalled deep-learning-reversi-0.0.1
Running setup.py develop for deep-learning-reversi
Successfully installed deep-learning-reversi
Note: you may need to restart the kernel to use updated packages.
実行結果
100回自己対局を行わせたところ、C++は1秒未満、Pythonでは15秒かかりました。かなりの高速化が図れました。
$ python performance_check.py
C++
0% [##############################] 100% | ETA: 00:00:00
Total time elapsed: 00:00:00
Title: C++
Started: 11/02/2019 09:38:17
Finished: 11/02/2019 09:38:17
Total time elapsed: 00:00:00
CPU %: 100.70
Memory %: 0.09
Python
0% [##############################] 100% | ETA: 00:00:00
Total time elapsed: 00:00:15
Title: Python
Started: 11/02/2019 09:38:17
Finished: 11/02/2019 09:38:33
Total time elapsed: 00:00:15
CPU %: 100.00
Memory %: 0.09