はじめに
株式会社クリエイスのモトキです。
前回、pandasでグラフを表示しました。
Anaconda環境でPyTorch 〜株価予想〜 #01 環境構築編
Anaconda環境でPyTorch 〜株価予想〜 #02 基礎知識・学習編
Anaconda環境でPyTorch 〜株価予想〜 #03 予測編
Anaconda環境でPyTorch 〜株価予想〜 #04 予測(リベンジ)編 (今回)
Anaconda環境でPyTorch 〜株価予想〜 #05 表示編
やったこと
とりあえず分からなくてもいいから動かそうということで新たにソースをお借りしました。
https://github.com/melissa135/mlp_stock
文字コードが中国語のものだったり、スペースタブが混じっているので修正してから始めます。
このソースではCSVから取得しています。しかもpandasで取り込んでおりquandlに変えても同じ様に扱えます。
date,open,high,low,close,volume,close_change,volume_change
"2005-01-04,二",3051.24,3051.24,3016.26,3025.42,435050970,0,0
...
学習
前回同様に、Appleの株価を拾ってくるようにsample_set.pyを修正します
import torch
import torch.utils.data as data
import quandl
class SampleSet(data.Dataset):
def __init__(self):
quandl.ApiConfig.api_key = "xxxx" # quandlのAPIKEY
df = quandl.get("WIKI/AAPL")
self.df = df
self.w1 = 1
self.w2 = 0.1
def __getitem__(self, index):
index = index + 5
array_data = [self.df['Adj. Close'][index - 1] * self.w1,
self.df['Adj. Close'][index - 2] * self.w1,
self.df['Adj. Close'][index - 3] * self.w1,
self.df['Adj. Close'][index - 4] * self.w1,
self.df['Adj. Close'][index - 5] * self.w1,
self.df['Adj. Volume'][index - 1] * self.w2,
self.df['Adj. Volume'][index - 2] * self.w2,
self.df['Adj. Volume'][index - 3] * self.w2,
self.df['Adj. Volume'][index - 4] * self.w2,
self.df['Adj. Volume'][index - 5] * self.w2]
target = [self.df['Adj. Close'][index]]
ret_data = torch.Tensor(array_data)
target = torch.Tensor(target)
return ret_data, target
def __len__(self):
return len(self.df) - 5
これで、同じように学習結果が保存できます。
あとは呼び出し元で引数を与えないように修正してやります。
実行
% python test_net.py
な…なんだこれは…
やはり「どう学習させるか」「どう予測させたいか」などを明確にして、
ソースコードを直す必要があるようです。
今回はここでギブアップ。
次回について
学習と予測それぞれの方針を決定、ソースコードに反映させます。
牛歩ですが、がんばります!
別チーム
別チームでもPyTorchの記事を書いています。
Docker環境でPyTorch 〜画像解析〜 #04 セクシー女優学習データ作成編