やりたいこと
streamlitを使って以下の機能を備えたtraderlの強化学習Webアプリを作ります。
- データの取得
- Agentの作成
- モデルのトレーニング
- 結果の表示
- モデルのセーブ
具体的には上の機能を持つ一つのAppクラスを作りサイドバーで選択できるようにする。
完成イメージ
必要なパッケージをインストールする
git clone https://github.com/komo135/trade-rl.git
cd trade-rl
pip install .
pip install streamlit
Webアプリの実行
以下のコマンドを入力してWebアプリを実行する
streamlit run https://raw.githubusercontent.com/komo135/traderl-web-app/main/github_app.py [ARGUMENTS]
streamlitとは
データ分析や機械学習のWebアプリケーションを簡単に作ることができるWebアプリケーションフレームワーク
traderlとは
強化学習を使ってFXや株式の取引を学習することができるpythonパッケージ
コード
アプリクラスとサイドバー関数の2つを作ります。
Appクラス
ホーム画面
ホーム画面に表示させるマークダウン
def home(self):
md = """
# Traderl Web Application
This web app is intuitive to [traderl](https://github.com/komo135/trade-rl).
# How to Execute
1. select data
* Click on "select data" on the sidebar to choose your data.
2. create agent
* Click "create agent" on the sidebar and select an agent name and arguments to create an agent.
3. training
* Click on "training" on the sidebar to train your model.
4. show results
* Click "show results" on the sidebar to review the training results.
"""
#マークダウンを画面に表示させる
st.markdown(md)
データの選択
def select_data(self):
file = None
#データの選択
select = st.selectbox("", ("forex", "stock", "url or path", "file upload"))
col1, col2 = st.columns(2)
# クリックするとデータが読みこまれる
load_file = st.button("load file")
if select == "forex":
symbol = col1.selectbox("", ("AUDJPY", "AUDUSD", "EURCHF", "EURGBP", "EURJPY", "EURUSD",
"GBPJPY", "GBPUSD", "USDCAD", "USDCHF", "USDJPY", "XAUUSD"))
timeframe = col2.selectbox("", ("m15", "m30", "h1", "h4", "d1"))
if load_file:
self.df = data.get_forex_data(symbol, timeframe)
elif select == "stock":
symbol = col1.text_input("", help="enter a stock symbol name")
if load_file:
self.df = data.get_stock_data(symbol)
elif select == "url or path":
file = col1.text_input("", help="enter url or local file path")
elif select == "file upload":
file = col1.file_uploader("", "csv")
if load_file and file:
st.write(file)
self.df = pd.read_csv(file)
if load_file:
st.write("Data selected")
def check_data(self):
f"""
# Select Data
"""
#データが存在しているかを確認する
if isinstance(self.df, pd.DataFrame):
st.write("Data already exists")
# データ既に存在していて初期化するかの確認のボタン
if st.button("change data"):
st.warning("data and agent have been initialized")
self.df = None
self.agent = None
#データがない場合新しくデータを選択する
if not isinstance(self.df, pd.DataFrame):
self.select_data()
エージェントの作成
#エージェントを作成する
def create_agent(self, agent_name, args):
agent_dict = {"dqn": dqn.DQN, "qrdqn":qrdqn.QRDQN}
self.agent = agent_dict[agent_name](**args)
#エージェントの選択、引数の選択
def agent_select(self):
# データが存在しない場合、警告を出す
if not isinstance(self.df, pd.DataFrame):
st.warning("data does not exist.\n"
"please select data")
return None
#エージェントの選択
agent_name = st.selectbox("", ("dqn", "qrdqn"), help="select agent")
"""
# select Args
"""
# 使用可能なtensorflowモデルの選択
col1, col2 = st.columns(2)
network = col1.selectbox("select network", (nn.available_network))
network_level = col2.selectbox("select network level", (f"b{i}" for i in range(8)))
network += "_" + network_level
self.model_name = network
#その他の引数の選択
col1, col2, col3, col4 = st.columns(4)
lr = float(col1.text_input("lr", "1e-4"))
n = int(col2.text_input("n", "3"))
risk = float(col3.text_input("risk", "0.01"))
pip_scale = int(col4.text_input("pip scale", "25"))
col1, col2 = st.columns(2)
gamma = float(col1.text_input("gamma", "0.99"))
use_device = col2.selectbox("use device", ("cpu", "gpu", "tpu"))
train_spread = float(col1.text_input("train_spread", "0.2"))
spread = int(col2.text_input("spread", "10"))
kwargs = {"df": self.df, "model_name": network, "lr": lr, "pip_scale": pip_scale, "n": n,
"use_device": use_device, "gamma": gamma, "train_spread": train_spread,
"spread": spread, "risk": risk}
#ボタンをクリックするとエージェントが作成される
if st.button("create agent"):
self.create_agent(agent_name, kwargs)
st.write("Agent created")
モデルのトレーニング
#エージェントが存在しているかを確認する
def agent_train(self)
#存在している場合、ボタンをクリックするとモデルがトレーニングされる
if self.agent:
if st.button("training"):
self.agent.train()
#ない場合、警告を出す
else:
st.warning("agent does not exist.\n"
"please create agent")
トレーニング結果の表示
def show_result(self):
#エージェントが存在しているかを確認する
if self.agent:
self.agent.plot_result(self.agent.best_w)
else:
st.warning("agent does not exist.\n"
"please create agent")
モデルのセーブ
def model_save(self):
# セーブするファイル名を入力してボタンをクリックしてモデルをセーブする
if self.agent:
save_name = st.text_input("save name", self.model_name)
if st.button("model save"):
self.agent.model.save(save_name)
st.write("Model saved.")
else:
st.warning("agent does not exist.\n"
"please create agent")
初期化
@staticmethod
def clear_cache():
if st.button("initialize"):
st.experimental_memo.clear()
## サイドバーの作成
```python
def sidebar():
return st.sidebar.radio("", ("Home", "select data", "create agent", "training",
"show results", "save model", "initialize"))
コードの実行
appをst.session_state
に保存してロードする理由
- サイドバーにある要素を選択する度に最初化からの実行になる為、データやエージェントが保存されない
-
st.session_state
この変数はページがロードされるまで保持される
appクラスとサイドバー関数を分ける理由
- サイドバー関数がappクラス内にあるとappクラス
st.session_state
からロードするとサイドバーが表示されなくなるから
if __name__ == "__main__":
st.set_page_config(layout="wide", )
if "app" in st.session_state:
app = st.session_state["app"]
else:
app = App()
select = sidebar()
if select == "Home":
app.home()
if select == "select data":
app.check_data()
elif select == "create agent":
app.agent_select()
elif select == "training":
app.agent_train()
elif select == "save model":
app.model_save()
elif select == "show results":
app.show_result()
st.session_state["app"] = app
if select == "initialize":
app.clear_cache()
github