今回はpythonを用いて日経平均株価の推移を予測してみようと思います。コードの詳しい解説などは参考にさせていただいた記事にありますのでこちらをご覧ください。
この記事では実装した際に躓いた点や、ちょっとした工夫を解説します。
##サンプルコード
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sys
from fbprophet import Prophet
from fbprophet.diagnostics import cross_validation
from fbprophet.diagnostics import performance_metrics
from fbprophet.plot import plot_cross_validation_metric
data = pd.DataFrame()
args = sys.argv
file_name = args[1] #ここでデータファイルを読み込む
data2 = pd.read_csv(file_name, skiprows=1, header=None, names=['ds','Open','High','Low','Close','Adj_Close','Volume'])
data3 = data2.dropna(how='any')
data = data.append(data3)
plt.style.use('ggplot')
fig = plt.figure(figsize=(18,8))
ax_Nikkei = fig.add_subplot(224)
ax_Nikkei.plot(data.loc["1965-04-21": ,['Adj_Close']], label='Nikkei', color='r')
ax_Nikkei.set_title('NIKKEI')
data['Adj_Close_log'] = np.log(data.Adj_Close).diff()
data.head()
data.tail()
fig = plt.figure(figsize=(18,8))
ax_Nikkei_log = fig.add_subplot(224)
ax_Nikkei_log.plot(data.loc["1965-04-21": ,['Adj_Close_log']], label='log diff', color='b')
ax_Nikkei_log.set_title('log earning rate')
model = Prophet()
model.fit(data.rename(columns={'Adj_Close':'y'}))
future_data = model.make_future_dataframe(periods=365, freq= 'd')
forecast_data = model.predict(future_data)
fig = model.plot(forecast_data)
df_cv = cross_validation(model, initial='730 days', period='180 days', horizon = '365 days')
df_cv.head()
df_p = performance_metrics(df_cv)
df_p.head()
#MSEをプロットする関数をデバッグのために再定義
def plot_cross_validation_metric(
df_cv, metric, rolling_window=0.1, ax=None, figsize=(10, 6)
):
"""Plot a performance metric vs. forecast horizon from cross validation.
Cross validation produces a collection of out-of-sample model predictions
that can be compared to actual values, at a range of different horizons
(distance from the cutoff). This computes a specified performance metric
for each prediction, and aggregated over a rolling window with horizon.
This uses fbprophet.diagnostics.performance_metrics to compute the metrics.
Valid values of metric are 'mse', 'rmse', 'mae', 'mape', and 'coverage'.
rolling_window is the proportion of data included in the rolling window of
aggregation. The default value of 0.1 means 10% of data are included in the
aggregation for computing the metric.
As a concrete example, if metric='mse', then this plot will show the
squared error for each cross validation prediction, along with the MSE
averaged over rolling windows of 10% of the data.
Parameters
----------
df_cv: The output from fbprophet.diagnostics.cross_validation.
metric: Metric name, one of ['mse', 'rmse', 'mae', 'mape', 'coverage'].
rolling_window: Proportion of data to use for rolling average of metric.
In [0, 1]. Defaults to 0.1.
ax: Optional matplotlib axis on which to plot. If not given, a new figure
will be created.
Returns
-------
a matplotlib figure.
"""
if ax is None:
fig = plt.figure(facecolor='w', figsize=figsize)
ax = fig.add_subplot(111)
else:
fig = ax.get_figure()
# Get the metric at the level of individual predictions, and with the rolling window.
df_none = performance_metrics(df_cv, metrics=[metric], rolling_window=0)
df_h = performance_metrics(df_cv, metrics=[metric], rolling_window=rolling_window)
# Some work because matplotlib does not handle timedelta
# Target ~10 ticks.
tick_w = max(df_none['horizon'].astype('timedelta64[ns]')) / 10.
# Find the largest time resolution that has <1 unit per bin.
dts = ['D', 'h', 'm', 's', 'ms', 'us', 'ns']
dt_names = [
'days', 'hours', 'minutes', 'seconds', 'milliseconds', 'microseconds',
'nanoseconds'
]
dt_conversions = [
24 * 60 * 60 * 10 ** 9,
60 * 60 * 10 ** 9,
60 * 10 ** 9,
10 ** 9,
10 ** 6,
10 ** 3,
1.,
]
for i, dt in enumerate(dts):
if np.timedelta64(1, dt) < np.timedelta64(tick_w, 'ns'):
break
x_plt = df_none['horizon'].astype('timedelta64[ns]').astype(np.int64) / float(dt_conversions[i])
x_plt_h = df_h['horizon'].astype('timedelta64[ns]').astype(np.int64) / float(dt_conversions[i])
ax.plot(x_plt, df_none[metric], '.', alpha=0.5, c='gray')
ax.plot(x_plt_h, df_h[metric], '-', c='b')
ax.grid(True)
ax.set_xlabel('Horizon ({})'.format(dt_names[i]))
ax.set_ylabel(metric)
return fig
# 日経平均株価をそのままモデリング(MSEが高くなり当てはまりが悪い)
fig = plot_cross_validation_metric(df_cv, metric='mse')#MSEをプロット
model2 = Prophet()
model2.fit(data.rename(columns={'Adj_Close_log':'y'}))
future_data2 = model2.make_future_dataframe(periods=365, freq= 'd')
forecast_data2 = model2.predict(future_data2)
fig2 = model2.plot(forecast_data2)
df_cv2 = cross_validation(model2, initial='730 days', period='180 days', horizon = '365 days')
df_cv2.head()
df_p2 = performance_metrics(df_cv2)
df_p2.head()
# 株価ではなく対数収益率をモデリング(当てはまりが高い)
fig = plot_cross_validation_metric(df_cv2, metric='mse')#MSEをプロット
plt.show()
##サンプルデータ
日経平均株価データはYAHOO!FINANCEのWebサイトからダウンロード可能。
リンク:https://finance.yahoo.com/quote/%5EN225/history?ltr=1
取得したい期間を自由に指定してcsv形式でダウンロード。(期間が長いほど予測が正確になります)
##ライブラリのインポート
・numpy・・・数値計算を効率的に行うためのライブラリ。
・pandas・・・csv形式のデータをPython上でデータフレームとして扱うために必要なライブラリ。
・matplotlib.pyplot・・・Pythonでデータを分析した結果をグラフで表示するためのライブラリ。
・Prophet・・・Pythonで時系列解析を実装できるFacebookが開発したライブラリ(pipでインストールが必要)
・cross_validation・・・データに対してCV(交差検証)を実行するライブラリ
・performance_metrics・・・予測モデルのパフォーマンスを示す値をまとめた表を出力するライブラリ。
・plot_cross_validation_metric・・・CV(交差検証)のパフォーマンスを示す値をまとめた表を出力するライブラリ。
・sys・・・予測するcsv形式のデータを選択する際に用いるライブラリ
###Prophetのインストール
Prophetのインストールには```$ pip install fbprophet
しかし、私はここで ```ModuleNotFoundError: No module named 'pystan'```というエラーが発生しました。pystanのライブラリが入っていないのかな?と思い```$ pip list```で見てみましたが、pystanはしっかりインストールされている、、、
いろいろ調べているとどうやらpystanのバージョンが原因のよう。3.0以降だとこのエラーが発生するらしい。インストールされていたのはバージョン3.0.2でした。
これを削除し、```$ pip install pystan==2.19.1.1```
でバージョン2.19をインストール
再度```$ pip install fbprophet```を実行するとインストールに成功しました。
##データの読み込みを簡単にする
参考にさせていただいた記事では直接データファイル名をコードに書いていました。
```python:stock.py
data = pd.DataFrame()
file_name = 'N225.csv'#ここ
data2 = pd.read_csv(file_name, skiprows=1, header=None, names=['ds','Open','High','Low','Close','Adj_Close','Volume'])
data3 = data2.dropna(how='any')
data = data.append(data3)
print(data)
しかし、いちいちコードを編集するのは面倒なので、今回はターミナルでデータファイルを指定して読み込ませます。
その際に使うライブラリがsysです。
import sys
args = sys.argv
file_name = args[1] #ここでデータファイルを読み込む
data2 = pd.read_csv(file_name, skiprows=1, header=None, names=['ds','Open','High','Low','Close','Adj_Close','Volume'])
このようにすることでターミナルで
$ python stock.py N225.csv
と入力すると同じように実行してくれます。
具体的には、args[0]はstock.py, args[1]はN225.csv にあたります。
これで違うデータファイルを読み込むときも
$ python stock.py データファイル名
で簡単に実行できます
##実行結果
###拡大したもの
他にもグラフが出力されますが専門的なものなので割愛します。
##まとめ
今回は初めてpythonで株価予測をしましたが、ライブラリを使えば比較的簡単に実装できました。コードの解説などはこちらに詳しく載っていますので参考にしてください。
今回はこれで終わります。ありがとうございました!
twitter: @siron_www 日常のこともつぶやいてるので友達感覚でフォローしてください😁