23
25

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Anaconda環境でPyTorch 〜株価予想〜 #04 予測(リベンジ)編

Last updated at Posted at 2018-10-05

はじめに

株式会社クリエイスモトキです。
前回、pandasでグラフを表示しました。

Anaconda環境でPyTorch 〜株価予想〜 #01 環境構築編
Anaconda環境でPyTorch 〜株価予想〜 #02 基礎知識・学習編
Anaconda環境でPyTorch 〜株価予想〜 #03 予測編
Anaconda環境でPyTorch 〜株価予想〜 #04 予測(リベンジ)編 (今回)
Anaconda環境でPyTorch 〜株価予想〜 #05 表示編

やったこと

とりあえず分からなくてもいいから動かそうということで新たにソースをお借りしました。
https://github.com/melissa135/mlp_stock
文字コードが中国語のものだったり、スペースタブが混じっているので修正してから始めます。

このソースではCSVから取得しています。しかもpandasで取り込んでおりquandlに変えても同じ様に扱えます。

sz_train.csv
date,open,high,low,close,volume,close_change,volume_change
"2005-01-04,二",3051.24,3051.24,3016.26,3025.42,435050970,0,0
...

学習

前回同様に、Appleの株価を拾ってくるようにsample_set.pyを修正します

Sample_set.py
import torch
import torch.utils.data as data
import quandl


class SampleSet(data.Dataset):
    def __init__(self):
        quandl.ApiConfig.api_key = "xxxx" # quandlのAPIKEY
        df = quandl.get("WIKI/AAPL")
        self.df = df
        self.w1 = 1
        self.w2 = 0.1

    def __getitem__(self, index):
        index = index + 5

        array_data = [self.df['Adj. Close'][index - 1] * self.w1,
                      self.df['Adj. Close'][index - 2] * self.w1,
                      self.df['Adj. Close'][index - 3] * self.w1,
                      self.df['Adj. Close'][index - 4] * self.w1,
                      self.df['Adj. Close'][index - 5] * self.w1,
                      self.df['Adj. Volume'][index - 1] * self.w2,
                      self.df['Adj. Volume'][index - 2] * self.w2,
                      self.df['Adj. Volume'][index - 3] * self.w2,
                      self.df['Adj. Volume'][index - 4] * self.w2,
                      self.df['Adj. Volume'][index - 5] * self.w2]

        target = [self.df['Adj. Close'][index]]

        ret_data = torch.Tensor(array_data)
        target = torch.Tensor(target)
        return ret_data, target

    def __len__(self):
        return len(self.df) - 5

これで、同じように学習結果が保存できます。
あとは呼び出し元で引数を与えないように修正してやります。

実行

terminal
% python test_net.py

image.png
な…なんだこれは…
やはり「どう学習させるか」「どう予測させたいか」などを明確にして、
ソースコードを直す必要があるようです。
今回はここでギブアップ。

次回について

学習と予測それぞれの方針を決定、ソースコードに反映させます。
牛歩ですが、がんばります!

別チーム

別チームでもPyTorchの記事を書いています。
Docker環境でPyTorch 〜画像解析〜 #04 セクシー女優学習データ作成編

Creaithメンバー

この記事の著者:曽宮モトキ
その他メンバー:志村上田

23
25
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
23
25

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?