3
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

streamlitで強化学習Webアプリを作る

Last updated at Posted at 2022-03-14

やりたいこと

streamlitを使って以下の機能を備えたtraderlの強化学習Webアプリを作ります。

  1. データの取得
  2. Agentの作成
  3. モデルのトレーニング
  4. 結果の表示
  5. モデルのセーブ

具体的には上の機能を持つ一つのAppクラスを作りサイドバーで選択できるようにする。

完成イメージ

image

必要なパッケージをインストールする

git clone https://github.com/komo135/trade-rl.git
cd trade-rl
pip install .

pip install streamlit

Webアプリの実行

以下のコマンドを入力してWebアプリを実行する

streamlit run https://raw.githubusercontent.com/komo135/traderl-web-app/main/github_app.py [ARGUMENTS]

streamlitとは

データ分析や機械学習のWebアプリケーションを簡単に作ることができるWebアプリケーションフレームワーク

traderlとは

強化学習を使ってFXや株式の取引を学習することができるpythonパッケージ

コード

アプリクラスとサイドバー関数の2つを作ります。

Appクラス

ホーム画面

ホーム画面に表示させるマークダウン

    def home(self):
        md = """
        # Traderl Web Application
        This web app is intuitive to [traderl](https://github.com/komo135/trade-rl).

        # How to Execute
        1. select data
            * Click on "select data" on the sidebar to choose your data.
        2. create agent
            * Click "create agent" on the sidebar and select an agent name and arguments to create an agent.
        3. training
            * Click on "training" on the sidebar to train your model.
        4. show results
            * Click "show results" on the sidebar to review the training results.
        """
        #マークダウンを画面に表示させる
        st.markdown(md)

データの選択

    def select_data(self):
        file = None
        
        #データの選択
        select = st.selectbox("", ("forex", "stock", "url or path", "file upload"))
        col1, col2 = st.columns(2)
        # クリックするとデータが読みこまれる
        load_file = st.button("load file")

        if select == "forex":
            symbol = col1.selectbox("", ("AUDJPY", "AUDUSD", "EURCHF", "EURGBP", "EURJPY", "EURUSD",
                                         "GBPJPY", "GBPUSD", "USDCAD", "USDCHF", "USDJPY", "XAUUSD"))
            timeframe = col2.selectbox("", ("m15", "m30", "h1", "h4", "d1"))
            if load_file:
                self.df = data.get_forex_data(symbol, timeframe)
        elif select == "stock":
            symbol = col1.text_input("", help="enter a stock symbol name")
            if load_file:
                self.df = data.get_stock_data(symbol)
        elif select == "url or path":
            file = col1.text_input("", help="enter url or local file path")
        elif select == "file upload":
            file = col1.file_uploader("", "csv")

        if load_file and file:
            st.write(file)
            self.df = pd.read_csv(file)

        if load_file:
            st.write("Data selected")

    def check_data(self):
        f"""
        # Select Data
        """
        #データが存在しているかを確認する
        if isinstance(self.df, pd.DataFrame):
            st.write("Data already exists")
            # データ既に存在していて初期化するかの確認のボタン
            if st.button("change data"):
                st.warning("data and agent have been initialized")
                self.df = None
                self.agent = None

        #データがない場合新しくデータを選択する
        if not isinstance(self.df, pd.DataFrame):
            self.select_data()

エージェントの作成

    #エージェントを作成する
    def create_agent(self, agent_name, args):
        agent_dict = {"dqn": dqn.DQN, "qrdqn":qrdqn.QRDQN}
        self.agent = agent_dict[agent_name](**args)

    #エージェントの選択、引数の選択
    def agent_select(self):
        # データが存在しない場合、警告を出す
        if not isinstance(self.df, pd.DataFrame):
            st.warning("data does not exist.\n"
                       "please select data")
            return None
        #エージェントの選択
        agent_name = st.selectbox("", ("dqn", "qrdqn"), help="select agent")

        """
        # select Args
        """
        # 使用可能なtensorflowモデルの選択
        col1, col2 = st.columns(2)
        network = col1.selectbox("select network", (nn.available_network))
        network_level = col2.selectbox("select network level", (f"b{i}" for i in range(8)))
        network += "_" + network_level
        self.model_name = network
        #その他の引数の選択
        col1, col2, col3, col4 = st.columns(4)
        lr = float(col1.text_input("lr", "1e-4"))
        n = int(col2.text_input("n", "3"))
        risk = float(col3.text_input("risk", "0.01"))
        pip_scale = int(col4.text_input("pip scale", "25"))
        col1, col2 = st.columns(2)
        gamma = float(col1.text_input("gamma", "0.99"))
        use_device = col2.selectbox("use device", ("cpu", "gpu", "tpu"))
        train_spread = float(col1.text_input("train_spread", "0.2"))
        spread = int(col2.text_input("spread", "10"))

        kwargs = {"df": self.df, "model_name": network, "lr": lr, "pip_scale": pip_scale, "n": n,
                  "use_device": use_device, "gamma": gamma, "train_spread": train_spread,
                  "spread": spread, "risk": risk}
       
        #ボタンをクリックするとエージェントが作成される
        if st.button("create agent"):
            self.create_agent(agent_name, kwargs)
            st.write("Agent created")

モデルのトレーニング

    #エージェントが存在しているかを確認する
    def agent_train(self)
        #存在している場合、ボタンをクリックするとモデルがトレーニングされる
        if self.agent:
            if st.button("training"):
                self.agent.train()
        #ない場合、警告を出す
        else:
            st.warning("agent does not exist.\n"
                       "please create agent")

トレーニング結果の表示

    def show_result(self):
        #エージェントが存在しているかを確認する
        if self.agent:
            self.agent.plot_result(self.agent.best_w)
        else:
            st.warning("agent does not exist.\n"
                       "please create agent")

モデルのセーブ

    def model_save(self):
        # セーブするファイル名を入力してボタンをクリックしてモデルをセーブする
        if self.agent:
            save_name = st.text_input("save name", self.model_name)
            if st.button("model save"):
                self.agent.model.save(save_name)
                st.write("Model saved.")
        else:
            st.warning("agent does not exist.\n"
                       "please create agent")

初期化

    @staticmethod
    def clear_cache():
        if st.button("initialize"):
            st.experimental_memo.clear()

## サイドバーの作成
```python
def sidebar():
    return st.sidebar.radio("", ("Home", "select data", "create agent", "training",
                                 "show results", "save model", "initialize"))

コードの実行

appをst.session_stateに保存してロードする理由

  • サイドバーにある要素を選択する度に最初化からの実行になる為、データやエージェントが保存されない
  • st.session_stateこの変数はページがロードされるまで保持される

appクラスとサイドバー関数を分ける理由

  • サイドバー関数がappクラス内にあるとappクラスst.session_stateからロードするとサイドバーが表示されなくなるから
if __name__ == "__main__":
    st.set_page_config(layout="wide", )

    if "app" in st.session_state:
        app = st.session_state["app"]
    else:
        app = App()

    select = sidebar()

    if select == "Home":
        app.home()

    if select == "select data":
        app.check_data()
    elif select == "create agent":
        app.agent_select()
    elif select == "training":
        app.agent_train()
    elif select == "save model":
        app.model_save()
    elif select == "show results":
        app.show_result()

    st.session_state["app"] = app
    if select == "initialize":
        app.clear_cache()

github

3
5
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?