LoginSignup
1
2

More than 3 years have passed since last update.

ChainerRLでbitmexの自動取引アルゴリズムに挑戦してみた

Last updated at Posted at 2019-09-01

以前は教師あり学習を用いて「仮想通貨の価格予測に挑戦」してみました。
しかし、結局のところ「どこで買ってどこで売ると儲かるの?」という学習をしないと意味が薄いということで、今回は強化学習に挑戦してみました。

お題目として考えたのは以下の2点です。

〇OHLCVデータがあれば他取引所・他金融商品(為替のヒストリカルデータを想定)にも転用可能であること
〇OHLCVデータから作成したインジケータを用いて、フレキシブルな学習が可能であること

というわけで、CSVの中から一部のカラムだけ読み込めるようにしてみたり
「入力するデータの列数・行数」に応じてobservation_spaceを計算させたりしています。

windows10(64bit)
intel Core i7-9700K
Memory 16GB
NVIDIA Geforce RTX2060

Python 3.6.9

chainer 6.3.0
chainerrl 0.7.0
gym 0.10.9
gym-ple 0.3
numpy 1.15.4

定義値として、いくつかの値を簡単に変更できるようにしました。

self.span
エージェントが過去何本分まで参照できるか
設定値では5分足288本=1日分となっています。

self.csv_filename
入力するCSVのファイル名

フルパス指定でも可、同フォルダ内であればファイル名のみで動きます。

self.use_cols
読み込むカラムの指定

CSVのイメージは下に掲載してあります。
※注文、決済、Q値の計算などに使うため、1つめのカラムには必ず終値を指定します。

以下ソースコードです。

環境定義用のファイル

bitmex_env.py
import sys

import gym
import numpy as np
import gym.spaces

class BitmexEnv(gym.Env):
    metadata = {'render.modes': ['human', 'ansi']}

    def __init__(self):

        #-------------↓各種定義値↓-------------
        self.span = 288 # 盤面に使う時間軸の本数

        self.start_funds = 10000 # シミュレーション時の初期資金

        self.slippage = 0.005 #  手数料・スリッページ

        self.csv_filename = "ticker.csv" # テスト用CSVデータのファイル名

        self.use_cols = (5,2,3,4,6) # テスト用CSVデータから読み込むデータ列(終値は先頭にすること)
        #-------------↑各種定義値↑-------------

        # 現在価格を取得
        self.input_data = self.get_price()

        # 行動範囲の設定(ロングを持つ、ショートを持つ、ノーポジにする)
        self.action_space = gym.spaces.Discrete(3)

        # 観測値の列数を計算する
        if np.ndim(self.input_data) > 1:
            obs_space = self.input_data.shape[0]
        elif np.ndim(self.input_data) == 1:
            obs_space = 1


        # 入力値の列数×行数 + ポジション(ロング、ショート、ノーポジ) + ポジション平均価格 + 現在価格
        self.observation_space = obs_space * self.span + 3

    def get_price(self):
        # 価格情報を読み込む関数
        # 読み込ませるCSVによって下記のように設定
        '''
        戻り値
        [0] = ['id']           # CSV1列目/行番号
        [1] = ['time_np']      # タイムスタンプ
        [2] = ['open_np']      # 始値
        [3] = ['high_np']      # 高値
        [4] = ['low_np']       # 安値
        [5] = ['close_np']     # 終値(発注・決済でも使用するため、必ずusecolsの先頭で読み込む)
        [6] = ['volume_np']    # 出来高
        '''
        data = np.loadtxt(self.csv_filename,       # 読み込みたいファイルのパス
        dtype="float32",
        delimiter=",",    # ファイルの区切り文字
        skiprows=0,       # 先頭の何行を無視するか(指定した行数までは読み込まない)
        usecols=self.use_cols # 読み込みたい列番号
        )
        input_data = data.transpose()

        return input_data

    def reset(self):
        # 環境を初期化する

        funds = self.start_funds # 総資産
        counter = self.span # 入力データセットの何行目から開始するか
        reward = 0 # 報酬
        signal = 0 # シグナル(-1:ショート/0:ノーポジ/1:ロング)

        self.action_count = 0
        self.random_count = 0

        # 空のndarrayを作成
        obs = np.empty(3, dtype='float32')
        obs[0] = 0
        obs[1] = 0
        # 末尾に盤面情報を付ける
        if np.ndim(self.input_data) > 1:
            obs[2] = self.input_data[0][counter-1]
        elif np.ndim(self.input_data) == 1:
            obs[2] = self.input_data[counter-1]
        obs = np.append(obs,self.make_tmp_data(obs, counter))

        return obs, funds, reward, counter

    def step(self, obs, signal, funds, reward, counter):
        # 観測局面を1つ進める関数

        # ポジション状況とシグナルを照合
        # ポジションありの場合
        if obs[0] != 0:

            # ポジションあり、かつシグナルが同方向の場合は何もしない
            if (obs[0] == 1 and signal == 1)\
                or (obs[0] == -1 and signal == -1):
                pass

            # ポジションあり、かつシグナルがStayの場合はポジションを決済する
            elif (obs[0] == 1 and signal == 0)\
                or (obs[0] == -1 and signal == 0):

                obs, funds, reward = self.close_order(obs, funds, reward)

            # ポジションあり、かつシグナルが逆方向の場合はポジションを決済してから逆注文を出す
            elif (obs[0] == 1 and signal == -1)\
                or (obs[0] == -1 and signal == 1):

                # ポジション決済関数を呼び出す
                obs, funds, reward = self.close_order(obs, funds, reward)

                # 注文処理
                obs = self.create_order(obs, signal)

        # ポジションなしの場合
        elif obs[0] == 0:

            # シグナルがStayの場合は何もしない
            if signal == 0:
                pass

            # シグナルがBuy、Sellの場合は新規注文を行う
            elif signal == 1 or signal == -1:

                # 注文処理
                obs = self.create_order(obs, signal)

        # 空のndarrayを作成
        next_obs = np.empty(3, dtype='float32')
        next_obs[0] = obs[0]
        next_obs[1] = obs[1]

        # 末尾に盤面情報を付ける
        if np.ndim(self.input_data) > 1:
            next_obs[2] = self.input_data[0][counter-1]
        elif np.ndim(self.input_data) == 1:
            next_obs[2] = self.input_data[counter-1]

        next_obs = np.append(next_obs,self.make_tmp_data(obs, counter))

        return next_obs, funds, reward, counter

    def create_order(self, obs, signal):
        # シグナルを受け取ってポジションを作成する関数
        # スリッページを計算して建玉時の平均金額を算出する

        obs[0] = signal

        if signal == 1:
            obs[1] = obs[2] + (obs[2] * self.slippage / 100)

        elif signal == -1:
            obs[1] = obs[2] - (obs[2] * self.slippage / 100)

        return obs

    def close_order(self, obs, funds, reward):
        # ポジションを決済する関数
        # 建玉の平均金額と現在の終値を使って資産情報を更新
        # 報酬は「建玉金額と決済金額の差額」を「建玉金額」で割ったもの(発散防止)

        if obs[0] == 1:
            funds = funds - obs[1] + obs[2]
            reward = reward + (- obs[1] + obs[2]) / obs[1]

        elif obs[0] == -1:
            funds = funds + obs[1] - obs[2]
            reward = reward + (obs[1] - obs[2]) / obs[1]

        obs[0] = 0
        obs[1] = 0

        return obs, funds, reward

    def make_tmp_data(self, obs, counter):
        # 観測値を作成する関数
        # 入力が1次元の場合
        if np.ndim(self.input_data) == 1:
            tmp_data = self.input_data[counter - self.span : counter]

        # 入力が多次元の場合
        elif np.ndim(self.input_data) > 1:
            tmp_data = self.input_data[:,counter - self.span : counter]
            tmp_data = tmp_data.reshape(1,tmp_data.size)

        return tmp_data

    def make_signal(self,action):
        # 学習内容に基づいてシグナルを作成する関数

        signal = action - 1
        self.action_count = self.action_count + 1

        return signal

    def random_action_func(self):
        # ランダム行動が選択された際のシグナルを作成する関数

        action = np.random.randint(0,3)
        self.random_count = self.random_count + 1
        self.action_count = self.action_count + 1

        return action

実行ファイル

bitmex_dqn.py
import chainer
import chainer.functions as F
import chainer.links as L
import chainerrl
import gym
import numpy as np
import setting_value
import copy
import bitmex_env

class QFunction(chainer.Chain):

    def __init__(self, obs_size, n_actions, n_hidden_channels=50):
        super(QFunction, self).__init__(
            l0=L.Linear(obs_size, n_hidden_channels),
            l1=L.Linear(n_hidden_channels, n_hidden_channels),
            l2=L.Linear(n_hidden_channels, n_actions))

    def __call__(self, x, test=False):
        """
        Args:
            x (ndarray or chainer.Variable): An observation
            test (bool): a flag indicating whether it is in test mode
        """
        h = F.relu(self.l0(x))
        h = F.relu(self.l1(h))
        return chainerrl.action_value.DiscreteActionValue(self.l2(h))

# 環境の作成と中身の確認
env = bitmex_env.BitmexEnv()
print('observation space:', env.observation_space)
print('action space:', env.action_space)

# 環境の初期化
# resetはループ内でも呼ばれているが、obs_sizeの確認に必要なのでここでも呼んでいる
obs, funds, reward, counter = env.reset()
print('initial observation:', obs)


obs_size = env.observation_space
n_actions = env.action_space.n
q_func = QFunction(obs_size, n_actions)
# q_func.to_gpu()

optimizer = chainer.optimizers.Adam(eps=1e-2)
optimizer.setup(q_func)

# 報酬の割引率
gamma = 1

# 最初は100%ランダムに行動、10000ステップで50%までランダム率を下げる
explorer = chainerrl.explorers.LinearDecayEpsilonGreedy(
    start_epsilon=1.0, end_epsilon=0.5, decay_steps=100000, random_action_func=env.random_action_func)

# 記憶用のバッファ
replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=10 ** 6)

# エージェントの作成
agent = chainerrl.agents.DoubleDQN(
    q_func, optimizer, replay_buffer, gamma, explorer,
    replay_start_size=50, update_interval=1,
    target_update_interval=10)

# エピソードの繰り返し回数
n_episodes = 10000

# エピソード開始
for i in range(1, n_episodes + 1):

    # 環境の初期化
    obs, funds, reward, counter = env.reset()
    reward = 0
    done = False
    last_state = None

    if np.ndim(env.input_data) == 1:
        loop_length = len(env.input_data)

    elif np.ndim(env.input_data) > 1:
        loop_length = len(env.input_data[0])

    for j in range(env.span + 1, loop_length):

        # アクションの作成
        action = agent.act_and_train(obs, reward)

        # シグナルの作成
        signal = env.make_signal(action)

        # ステップの実行
        obs, funds, reward, counter = env.step(obs, signal, funds, reward, counter)

    if i % 1 == 0:
        # 10エピソードごとに結果を出力
        print('episode:', i,
              'R:', reward,
              'F:', funds,
              'Random:', env.random_count,
              'statistics:', agent.get_statistics())
    agent.stop_episode_and_train(obs, reward, done)
print('Finished.')

CSVファイル
(行番号、タイムスタンプ、始値、高値、安値、終値、出来高)

ticker.csv
1,1555293600,578392,578925,577790,577893,5.75715
2,1555293900,578452,578998,577844,578430,166.607
3,1555294200,578430,578610,577817,578000,16.7402
4,1555294500,577818,578708,577817,578708,31.8986
5,1555294800,578568,578789,577600,577745,14.5544
6,1555295100,577741,578116,577500,577600,9.46102
7,1555295400,577600,578064,576999,577617,95.2678
8,1555295700,578000,578310,577006,578310,12.5638
9,1555296000,577940,578310,577703,577761,6.43372
10,1555296300,577759,578157,577500,577909,7.37702
11,1555296600,577911,578523,577768,578498,6.5487
12,1555296900,578505,578519,577915,578499,6.2731
13,1555297200,578496,578497,577900,577901,2.13157
14,1555297500,578026,578222,577000,577000,8.87903
15,1555297800,577661,577661,577000,577450,3.9577
16,1555298100,577450,577820,577450,577500,11.8182
17,1555298400,577540,577820,577409,577423,12.081
18,1555298700,577432,577949,577432,577663,14.9078
19,1555299000,577663,577663,577409,577409,54.0184
20,1555299300,577409,577465,576619,576967,35.5846
21,1555299600,576968,577399,576703,577367,17.3595
22,1555299900,577365,577372,576551,577335,32.0518
23,1555300200,576878,577600,576746,577486,19.9909
24,1555300500,577486,577950,576901,576920,45.4839
25,1555300800,577050,577596,576921,577196,18.9082
26,1555301100,577200,577866,576895,577814,20.5026
27,1555301400,577880,578194,577335,578123,20.293
28,1555301700,577649,577700,577308,577308,2.61217
29,1555302000,577876,578107,577454,577572,6.79842
30,1555302300,577574,578040,577574,578028,6.58562
31,1555302600,578028,578871,577664,578376,27.302
32,1555302900,578385,578498,577454,578398,11.5051
33,1555303200,578197,578199,577523,577999,27.7094
34,1555303500,577999,579000,577934,578216,26.192
35,1555303800,578605,579000,578218,578436,14.8183
36,1555304100,578435,578999,578431,578475,16.2454
37,1555304400,578807,580000,578485,579338,54.8877
38,1555304700,579339,579837,579001,579419,9.75745
39,1555305000,579419,579450,578786,579219,2.76562
40,1555305300,579384,579500,578813,579160,19.7781
41,1555305600,579000,579507,578904,579407,10.9652
42,1555305900,579313,579314,578823,578823,7.3185
43,1555306200,578823,579031,577500,578234,19.3347
44,1555306500,577604,578155,577310,577964,15.6121
45,1555306800,577981,578235,577537,578000,4.3552
46,1555307100,578000,578307,578000,578187,23.0904
47,1555307400,578187,578685,578152,578685,12.388
48,1555307700,578677,578677,578152,578152,3.12501
49,1555308000,578152,578512,578000,578124,4.82603
50,1555308300,578120,578677,578043,578501,1.79099

実際は手元にある16000行ほどのファイルを使用して動作テストをしています。

結果

〇OHLCV×288本を使って10000エピソード学習させたが、Rは伸びなかった。
〇正常に学習できていない?ネットワークの構築がおかしい?環境構築がおかしい?
メモ→12層、ニューロン数4倍×3層、2倍×2層にしたところ若干数値がマシに。

〇学習回数が足りない?入力値がよくない?
メモ→学習回数増やすと1週間あっても学習終わらなそうなのでとりあえずそれはパス。
OHLCVそれぞれ正規化してしまうと実使用時に同一環境を再現できなくなりそう。
(cryptowatchのAPI取得だと6000本、6000本の正規化データと同じになるとは思えないし…)
ステップ毎にエージェントに渡すデータ数を正規化してみる?

〇インジケータを積み込んでみたら結果は良化するのか?
メモ→インジケータ作成器を使えば比較的簡単に作成はできそう。
問題は↓に書いた互換性。

〇そもそもプログラム自体がイケてない?

〇やたら時間がかかる?(しかもGPUが動いている気配が見えない?)
メモ→numpyをcupyに置き換えてGPUに突っ込まないといけないらしいのだが、cupyにnp.appendの互換メソッドがないので再設計が必要。だと思われる。

などなど、反省点多々ありました。
ひとまず動くようにはなりましたが、今後とも改良を続けていきます。

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2