0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

4x4リバーシ C++による高速化

Last updated at Posted at 2019-11-02

概要

4x4リバーシのPythonプログラムを作成しましたが、高速化するため、C++で記述し、pybind11を使ってPythonから呼び出せるようにします。ロジックはPythonプログラムと同一です。

環境

ホストOS Ubuntu18.04 にdockerをインストール、docker内でPythonを動かしています。

  • Ubuntu18.04
  • nvidia-driver-410
  • nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04
  • Python 3.7.3
  • tensorflow-gpu 1.13.1
  • g++ (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0

プログラム

配置

deep-learning-reversi/
├ dlreversi4/
│ ├ cpp/
│ │ ├ Game.hpp
│ │ ├ State.hpp
│ │ ├ State.cpp
│ │ └ reversi4.cpp
│ └ game.py
├ performance_check.py
└ setup.py

State

ゲームの状態を表すクラスです。game.pyのStateクラスに当たります。
ヘッダーファイル

State.hpp
#ifndef STATE_HPP_INCLUDED
#define STATE_HPP_INCLUDED

#include <stdint.h>

#include <array>
#include <cstdint>  // std::uint16_t
#include <string>   // std::string (return type of toString)
#include <vector>

/// Position of a 4x4 reversi game.
///
/// The board is stored twice: as arrays with one entry per square (non-zero
/// means "stone present") and as 16-bit masks in which square index i maps to
/// bit (0x8000 >> i). `pieces` always belongs to the player who moves next;
/// next() swaps the two sides after applying a move.
class State
{
public:
    static const unsigned BOARD_SIZE = 4;
    static const unsigned BOARD_AREA = BOARD_SIZE * BOARD_SIZE;
    /// Action value meaning "pass"; it is one past the last square index.
    static const unsigned PASS = BOARD_SIZE * BOARD_SIZE;

    /// Builds the starting position at depth 0.
    State();

    /// Builds a position from explicit stone arrays and a move count.
    State(const std::array<unsigned, BOARD_AREA> pieces, const std::array<unsigned, BOARD_AREA> enemyPieces, const unsigned depth);

    ~State() {}

    /// True when the game is over and the player to move has fewer stones.
    bool isLose() const;
    /// True when the game is over and both players have the same stone count.
    bool isDraw() const;
    /// True when the board is full or the game ended by consecutive passes.
    bool isDone() const;
    /// True when `depth` is even.
    bool isFirstPlayer() const;
    /// Returns the position after `action` (a square index, or PASS).
    State next(const unsigned& action);
    /// Renders the board as text, one line per row.
    std::string toString() const;

    /// Stone counts for the player to move and for the opponent.
    unsigned piecesCount() const;
    unsigned enemyPiecesCount() const;

    // Stone placement, one entry per square.
    std::array<unsigned, BOARD_AREA> pieces;
    std::array<unsigned, BOARD_AREA> enemyPieces;
    // Stone placement as bit masks.
    std::uint16_t piecesBit;
    std::uint16_t enemyPiecesBit;
    // Number of moves played so far.
    unsigned depth;
    // Legal actions for the player to move; holds only PASS when there is none.
    std::vector<unsigned> legalActions;

private:
    // Masks applied to the opponent's stones before scanning along a direction.
    static const std::uint16_t MASK_EDGE		= 0x0660;
    static const std::uint16_t MASK_VERTICAL	= 0x6666;
    static const std::uint16_t MASK_HORIZONTAL	= 0x0FF0;

    // Shift distances, in bits, for each scan direction.
    static const unsigned DIR_LEFT_OBRIQUE	= 5;
    static const unsigned DIR_RIGHT_OBLIQUE	= 3;
    static const unsigned DIR_HORIZONTAL	= 1;
    static const unsigned DIR_VERTICAL		= 4;

    // Number of set bits in a mask.
    unsigned countBit(const std::uint16_t& bit) const;
    // Legal moves of the player to move, as a bit mask.
    std::uint16_t generateLegal() const;
    // Candidate squares found by scanning along a single direction.
    std::uint16_t generateSomeLegal(const std::uint16_t& pieces_bit, const std::uint16_t& enemy_pieces_bit, const unsigned& direction) const;
    // Whether the given square is a legal move.
    bool isLegalAction(const std::uint16_t& actionBit) const;
    // Places a stone and flips the captured stones.
    void move(const std::uint16_t& actionBit);
    // Stones captured along a single direction by placing a stone on actionBit.
    std::uint16_t generateSomeFlipped(
        const std::uint16_t& piecesBit, const std::uint16_t& enemyPiecesBit, const std::uint16_t& actionBit, const unsigned& direction) const;
    // Converts a bit mask to the per-square array form.
    std::array<unsigned, BOARD_AREA> bitToArray(const std::uint16_t& bit) const;
    // Converts the per-square array form to a bit mask.
    std::uint16_t arrayToBit(const std::array<unsigned, BOARD_AREA>& array) const;
    // Converts a legal-move mask to a list of square indices.
    std::vector<unsigned> legalActionToAraay(const std::uint16_t& bit) const;

    // Set when the game ended because of consecutive passes.
    bool passEnd;
    // Legal moves as a bit mask.
    std::uint16_t legalActionsBit;
};

#endif

ソースファイル

State.cpp
#include "State.hpp"
#include <iostream>
#include <sstream>

// ゲーム状態
State::State()
{
    // 連続パスによる終了
    passEnd = false;

    // 石の配置
    this->pieces = { 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0 };
    this->enemyPieces = { 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0 };
    this->piecesBit = arrayToBit(this->pieces);
    this->enemyPiecesBit = arrayToBit(this->enemyPieces);

    // 手数
    this->depth = 0;

    // 合法手の設定
    this->legalActionsBit = generateLegal();
    this->legalActions = legalActionToAraay(this->legalActionsBit);
}

// Builds a position from explicit stone arrays and a move count.
// The bit masks are derived from the arrays; the legal moves are then
// computed from the masks.
State::State(const std::array<unsigned, BOARD_AREA> pieces, const std::array<unsigned, BOARD_AREA> enemyPieces, const unsigned depth)
    : pieces(pieces),
      enemyPieces(enemyPieces),
      piecesBit(arrayToBit(pieces)),
      enemyPiecesBit(arrayToBit(enemyPieces)),
      depth(depth)
{
    // A freshly built position has not ended by consecutive passes.
    passEnd = false;

    // Legal moves depend on the masks initialised above.
    legalActionsBit = generateLegal();
    legalActions = legalActionToAraay(legalActionsBit);
}

// 負けかどうか
bool State::isLose() const
{
    return isDone() && (countBit(this->piecesBit) < countBit(this->enemyPiecesBit));
}

// 引き分けかどうか
bool State::isDraw() const
{
    return isDone() && (countBit(this->piecesBit) == countBit(this->enemyPiecesBit));
}

// ゲーム終了かどうか
bool State::isDone() const
{
    return (countBit(this->piecesBit | this->enemyPiecesBit) == BOARD_AREA) || passEnd;
}

// 先手かどうか
bool State::isFirstPlayer() const
{
    return this->depth % 2 == 0;
}
// Returns the position reached after `action`.
//
// `action` is a square index in [0, BOARD_AREA) or PASS. Any value at or
// beyond BOARD_AREA is treated as a pass; checking the range first also keeps
// the shift below well-defined for arbitrarily large values. A square that is
// not a legal move places no stone, but the turn still changes hands.
// The returned position has the two sides swapped, so `pieces` again belongs
// to the player to move, and its depth is one greater.
State State::next(const unsigned& action)
{
    const std::uint16_t actionBit =
        (action < BOARD_AREA) ? static_cast<std::uint16_t>(0x8000u >> action) : static_cast<std::uint16_t>(0);
    State state(this->pieces, this->enemyPieces, this->depth + 1);

    if (actionBit != 0 && isLegalAction(actionBit))
        state.move(actionBit);

    // Hand the turn over: swap both representations of the two sides.
    const std::array<unsigned, BOARD_AREA> movedPieces = state.pieces;
    const std::uint16_t movedPiecesBit = state.piecesBit;
    state.pieces = state.enemyPieces;
    state.piecesBit = state.enemyPiecesBit;
    state.enemyPieces = movedPieces;
    state.enemyPiecesBit = movedPiecesBit;

    // Recompute the legal moves from the new mover's point of view.
    state.legalActionsBit = state.generateLegal();
    state.legalActions = legalActionToAraay(state.legalActionsBit);

    // The game ends when this action was a pass and the next player has no move either.
    if (actionBit == 0 && state.legalActionsBit == 0)
        state.passEnd = true;

    return state;
}

// 石の数
unsigned State::piecesCount() const
{
    return countBit(this->piecesBit);
}

unsigned State::enemyPiecesCount() const
{
    return countBit(this->enemyPiecesBit);
}

// Counts the set bits of a 16-bit mask by summing neighbouring bit groups of
// doubling width (1, 2, 4, then 8 bits).
unsigned State::countBit(const std::uint16_t& bit) const
{
    unsigned acc = bit;
    acc = (acc & 0x5555u) + ((acc >> 1) & 0x5555u);
    acc = (acc & 0x3333u) + ((acc >> 2) & 0x3333u);
    acc = (acc & 0x0f0fu) + ((acc >> 4) & 0x0f0fu);
    acc = (acc & 0x00ffu) + ((acc >> 8) & 0x00ffu);
    return acc;
}

// Legal moves of the player to move, as a bit mask.
// Collects the candidate squares of the four scan directions and keeps only
// those that are still empty.
std::uint16_t State::generateLegal() const
{
    const std::uint16_t own = this->piecesBit;
    const std::uint16_t enemy = this->enemyPiecesBit;

    std::uint16_t candidates = generateSomeLegal(own, enemy & MASK_EDGE, DIR_LEFT_OBRIQUE);
    candidates |= generateSomeLegal(own, enemy & MASK_EDGE, DIR_RIGHT_OBLIQUE);
    candidates |= generateSomeLegal(own, enemy & MASK_VERTICAL, DIR_HORIZONTAL);
    candidates |= generateSomeLegal(own, enemy & MASK_HORIZONTAL, DIR_VERTICAL);

    const std::uint16_t occupied = own | enemy;
    return static_cast<std::uint16_t>(~occupied & candidates);
}

// Candidate squares along one direction.
// Starts from the (pre-masked) enemy stones adjacent to own stones, extends
// that set through further adjacent enemy stones, and returns the neighbours
// of the resulting set in both shift directions. The caller removes occupied
// squares from the result.
inline std::uint16_t State::generateSomeLegal(const std::uint16_t& pieces_bit, const std::uint16_t& enemy_pieces_bit, const unsigned& direction) const
{
    // Squares one step away from `bits` in either shift direction.
    const auto neighbours = [direction](const std::uint16_t bits) -> unsigned {
        return static_cast<unsigned>((bits << direction) | (bits >> direction));
    };

    std::uint16_t chain = static_cast<std::uint16_t>(neighbours(pieces_bit) & enemy_pieces_bit);
    unsigned remaining = 6;
    while (remaining-- > 0)
        chain = static_cast<std::uint16_t>(chain | (neighbours(chain) & enemy_pieces_bit));

    return static_cast<std::uint16_t>(neighbours(chain));
}

// Whether the square given by `actionBit` is a legal move in this position.
// The square is legal only when its bit is set in the legal-move mask. This
// has to be a bitwise test: a logical AND would accept every non-zero square
// as soon as any legal move exists.
inline bool State::isLegalAction(const std::uint16_t& actionBit) const
{
    return (actionBit & this->legalActionsBit) != 0;
}

// Places a stone of the player to move on `actionBit` and flips every enemy
// stone captured along the four scan directions. Both the bit masks and the
// array forms are updated.
void State::move(const std::uint16_t& actionBit)
{
    const std::uint16_t own = this->piecesBit;
    const std::uint16_t enemy = this->enemyPiecesBit;

    std::uint16_t captured = 0;
    captured |= generateSomeFlipped(own, enemy & MASK_EDGE, actionBit, DIR_LEFT_OBRIQUE);
    captured |= generateSomeFlipped(own, enemy & MASK_EDGE, actionBit, DIR_RIGHT_OBLIQUE);
    captured |= generateSomeFlipped(own, enemy & MASK_VERTICAL, actionBit, DIR_HORIZONTAL);
    captured |= generateSomeFlipped(own, enemy & MASK_HORIZONTAL, actionBit, DIR_VERTICAL);

    // Captured stones change owner; the new stone joins the mover's side.
    this->piecesBit = own | actionBit | captured;
    this->enemyPiecesBit = enemy ^ captured;

    this->pieces = bitToArray(this->piecesBit);
    this->enemyPieces = bitToArray(this->enemyPiecesBit);
}

// Enemy stones captured along one direction by placing a stone on `actionBit`.
// Each side of the new stone is handled separately and covers a run of one or
// two enemy stones closed off by an own stone.
inline std::uint16_t State::generateSomeFlipped(
    const std::uint16_t& piecesBit, const std::uint16_t& enemyPiecesBit, const std::uint16_t& actionBit, const unsigned& direction) const
{
    // Enemy stones directly next to the new stone, one per shift direction.
    const std::uint16_t nextShl = (actionBit << direction) & enemyPiecesBit;
    const std::uint16_t nextShr = (actionBit >> direction) & enemyPiecesBit;
    // Enemy stones that have an own stone directly next to them, one per shift direction.
    const std::uint16_t ownShl = (piecesBit << direction) & enemyPiecesBit;
    const std::uint16_t ownShr = (piecesBit >> direction) & enemyPiecesBit;

    std::uint16_t captured = 0;
    // "<<" side: the adjacent enemy stone, closed off one or two squares further on.
    captured |= nextShl & (ownShr | (ownShr >> direction));
    // "<<" side: the second enemy stone of a two-stone run.
    captured |= (nextShl << direction) & ownShr;
    // ">>" side: the adjacent enemy stone, closed off one or two squares further on.
    captured |= nextShr & (ownShl | (ownShl << direction));
    // ">>" side: the second enemy stone of a two-stone run.
    captured |= (nextShr >> direction) & ownShl;

    return captured;
}

// Expands a bit mask into the per-square array form: entry i is 1 when
// bit (0x8000 >> i) is set and 0 otherwise.
inline std::array<unsigned, State::BOARD_AREA> State::bitToArray(const std::uint16_t& bit) const
{
    std::array<unsigned, BOARD_AREA> cells;
    for (unsigned square = 0; square < BOARD_AREA; ++square)
    {
        const unsigned squareMask = 0x8000u >> square;
        cells[square] = (bit & squareMask) ? 1u : 0u;
    }

    return cells;
}

// 石の配列表現をビットに変換
inline std::uint16_t State::arrayToBit(const std::array<unsigned, BOARD_AREA>& array) const
{
    std::uint16_t bit = 0;
    std::uint16_t mask = 0x8000;
    for (unsigned i = 0; i < array.size(); ++i)
    {
        if (array[i] > 0)
            bit |= mask;
        mask = mask >> 1;
    }

    return bit;
}

// Converts a legal-move mask into a list of square indices in ascending order.
// An empty mask yields a list that contains only PASS.
inline std::vector<unsigned> State::legalActionToAraay(const std::uint16_t& bit) const
{
    if (bit == 0)
    {
        std::vector<unsigned> onlyPass(1);
        onlyPass.front() = PASS;
        return onlyPass;
    }

    std::vector<unsigned> actions;
    actions.reserve(countBit(bit));
    for (unsigned square = 0; square < BOARD_AREA; ++square)
    {
        if ((bit & (0x8000u >> square)) != 0)
            actions.push_back(square);
    }

    return actions;
}

// Renders the board as text with a newline after every fourth square.
// The first player's stones are drawn as 'o', the second player's as 'x',
// and empty squares as '-'.
std::string State::toString() const
{
    const bool firstToMove = isFirstPlayer();
    const char ownMark = firstToMove ? 'o' : 'x';
    const char enemyMark = firstToMove ? 'x' : 'o';

    std::string text;
    for (unsigned square = 0; square < BOARD_AREA; ++square)
    {
        const unsigned squareMask = 0x8000u >> square;

        if ((this->piecesBit & squareMask) != 0)
            text += ownMark;
        else if ((this->enemyPiecesBit & squareMask) != 0)
            text += enemyMark;
        else
            text += '-';

        if (square % 4 == 3)
            text += '\n';
    }

    return text;
}

Game

Game.hpp
#ifndef GAME_HPP_INCLUDED
#define GAME_HPP_INCLUDED

#include "State.hpp"

#include <random>

// Picks one of the legal actions of `state` uniformly at random.
// Returns State::PASS when the list is empty and the only entry when there is
// exactly one.
//
// The index is drawn from an integer distribution over [0, size - 1]. A real
// distribution over [0, size - 1) truncated to an integer can never produce
// the last index. The engine is seeded once per thread instead of on every
// call. `inline` allows this header to be included from several translation
// units.
inline unsigned randomAction(const State& state)
{
    const auto size = state.legalActions.size();

    if (size == 0)
        return State::PASS;
    if (size == 1)
        return state.legalActions.front();

    static thread_local std::mt19937 engine{ std::random_device{}() };
    std::uniform_int_distribution<std::vector<unsigned>::size_type> pickIndex(0, size - 1);

    return state.legalActions[pickIndex(engine)];
}

#endif

reversi4.cpp

reversi4.cpp
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <cstdint>
#include <string>
#include "State.hpp"
#include "Game.hpp"

namespace py = pybind11;

// Python extension module `cpp`: exposes randomAction as `random_action` and
// the State class with snake_case attribute and method names.
PYBIND11_MODULE(cpp, m)
{
    m.def("random_action", &randomAction);

    py::class_<State>(m, "State")
        // Default constructor and (pieces, enemy_pieces, depth) constructor.
        .def(py::init<>())
        .def(py::init<std::array<unsigned, State::BOARD_AREA>, std::array<unsigned, State::BOARD_AREA>, unsigned>())
        // Read-only views of the position.
        .def_readonly("pieces", &State::pieces)
        .def_readonly("enemy_pieces", &State::enemyPieces)
        .def_readonly("pieces_bit", &State::piecesBit)
        .def_readonly("enemy_pieces_bit", &State::enemyPiecesBit)
        .def_readonly("depth", &State::depth)
        .def_readonly("legal_actions", &State::legalActions)
        // Queries and the state transition.
        .def("is_lose", &State::isLose)
        .def("is_draw", &State::isDraw)
        .def("is_done", &State::isDone)
        .def("is_first_player", &State::isFirstPlayer)
        .def("next", &State::next)
        .def("pieces_count", &State::piecesCount)
        .def("enemy_pieces_count", &State::enemyPiecesCount)
        .def("__str__", &State::toString);

    m.doc() = "cpp reversi4 plugin";

#ifdef VERSION_INFO
    m.attr("__version__") = VERSION_INFO;
#else
    m.attr("__version__") = "dev";
#endif
}

setup.py

setup.py
from setuptools import setup
from setuptools import find_packages

import os, sys
import pybind11
from setuptools import Extension
from distutils import sysconfig

# Remove the "-Wstrict-prototypes" compiler option, which isn't valid for C++.
import distutils.sysconfig
cfg_vars = distutils.sysconfig.get_config_vars()
for key, value in cfg_vars.items():
    if isinstance(value, str):
        cfg_vars[key] = value.replace("-Wstrict-prototypes", "")

# Build options shared by every extension module.
# pybind11.get_include() already returns the directory that contains the
# "pybind11/" header folder, so it is passed to the compiler unchanged.
kwds = dict(
    extra_compile_args=['-std=c++11'],
    include_dirs=[
        pybind11.get_include(True),
        pybind11.get_include(False)
    ],
)

# Extra flags used when building on macOS.
if sys.platform == 'darwin':
    kwds["extra_compile_args"].append('-mmacosx-version-min=10.7')
    kwds["extra_compile_args"].append('-stdlib=libc++')


# The C++ sources are compiled into the extension module "dlreversi4.cpp".
ext_modules = [
    Extension(
        'dlreversi4.cpp',
        sources=[
            os.path.join('dlreversi4', 'cpp', 'State.cpp'),
            os.path.join('dlreversi4', 'cpp', 'reversi4.cpp')
        ],
        **kwds
    )
]

setup(
    name = 'deep-learning-reversi',
    version = '0.0.1',
    author = '',
    author_email='',
    packages = ['dlreversi4', 'dlreversi4.cpp'],
    install_requires = ['pyprind', 'psutil'],
    license = 'MIT',
    zip_safe = False,
    description='Deep Learning Reversi',
    ext_modules=ext_modules
)

game.py

game.py を参照してください。

performance_check.py

performance_check.py
from dlreversi4 import game
import dlreversi4.cpp as cpp
import pyprind
import argparse

def play(state, algorithm):
    """Play one game from `state` until it reports that it is done.

    `algorithm` is called with the current state and returns the action to
    play; the state is then replaced by `state.next(action)`. When the
    module-level `args.verbose` flag is set, every new position is printed.
    """
    while not state.is_done():
        # Choose an action and advance to the next state.
        action = algorithm(state)
        state = state.next(action)

        if not args.verbose:
            continue

        # Show the action just played and the resulting position.
        print('action: {}, dept: {}, {}'.format(action, state.depth, state.legal_actions))
        print(state)
        print()

if __name__ == '__main__':
    # Parse the command-line options.
    parser = argparse.ArgumentParser(
        description='Compare the speed of the C++ and Python reversi implementations.')
    parser.add_argument('--trials', '-t', type=int, default=10000, help='number of trials')
    parser.add_argument('--algorithm', '-a', type=str, default='random', help='algorithm')
    parser.add_argument('--verbose', '-v', action='store_true')
    args = parser.parse_args()

    # Run the same number of games with each implementation and print the
    # timing report of its progress bar. The action function is looked up by
    # name with getattr, so the command-line value is never evaluated as code.
    action_name = '{}_action'.format(args.algorithm)
    for title, module in (('C++', cpp), ('Python', game)):
        bar = pyprind.ProgBar(args.trials, monitor=True, title=title)
        algorithm = getattr(module, action_name)
        for _ in range(args.trials):
            play(module.State(), algorithm)
            bar.update()
        print(bar)

インストール

上記プログラムをpipでPython環境にインストールします。setup.pyのあるディレクトリーに移動し、下記コマンドを実行します。

$ pip install -e . --no-cache-dir
Obtaining file:///ai
Requirement already satisfied: pyprind in /root/.pyenv/versions/3.7.3/lib/python3.7/site-packages (from deep-learning-reversi==0.0.1) (2.11.2)
Requirement already satisfied: psutil in /root/.pyenv/versions/3.7.3/lib/python3.7/site-packages (from deep-learning-reversi==0.0.1) (5.6.3)
Installing collected packages: deep-learning-reversi
  Found existing installation: deep-learning-reversi 0.0.1
    Uninstalling deep-learning-reversi-0.0.1:
      Successfully uninstalled deep-learning-reversi-0.0.1
  Running setup.py develop for deep-learning-reversi
Successfully installed deep-learning-reversi
Note: you may need to restart the kernel to use updated packages.

実行結果

100回自己対局を行わせたところ、C++は1秒未満、Pythonでは15秒かかりました。かなりの高速化が図れました。

$ python performance_check.py 
C++
0% [##############################] 100% | ETA: 00:00:00
Total time elapsed: 00:00:00
Title: C++
  Started: 11/02/2019 09:38:17
  Finished: 11/02/2019 09:38:17
  Total time elapsed: 00:00:00
  CPU %: 100.70
  Memory %: 0.09
Python
0% [##############################] 100% | ETA: 00:00:00
Total time elapsed: 00:00:15
Title: Python
  Started: 11/02/2019 09:38:17
  Finished: 11/02/2019 09:38:33
  Total time elapsed: 00:00:15
  CPU %: 100.00
  Memory %: 0.09
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?