LoginSignup
29
43

More than 5 years have passed since last update.

Anaconda環境でPyTorch 〜株価予想〜 #02 基礎知識・学習編

Last updated at Posted at 2018-07-27

はじめに

株式会社クリエイスモトキです。

前回は、ざっくり環境構築と試しに株価を取得しました。

Anaconda環境でPyTorch 〜株価予想〜 #01 環境構築編
Anaconda環境でPyTorch 〜株価予想〜 #02 基礎知識・学習編(今回)
Anaconda環境でPyTorch 〜株価予想〜 #03 予測編
Anaconda環境でPyTorch 〜株価予想〜 #04 予測(リベンジ)編
Anaconda環境でPyTorch 〜株価予想〜 #05 表示編

今回は、読み漁っている中でぶち当たった壁を壊す為に調べたことを並べていきます。
(機械学習のソースを書く以前に、人が書いたコードが読めなくて進められなかった)

別チーム

別チームでもPyTorchの記事を書いています。
Docker環境でPyTorch 〜画像解析〜 #02 モデル訓練&保存編

基礎知識

知らないことだらけでも、概要くらいは理解しないとぐぐれないし進めないので、
今回は、機械学習の基礎知識を固めていきます。
それぞれ浅くほっていきます。

キーワード 概要 ソースで登場する箇所
RNN(Recurrent Neural Network) 再帰型ニューラルネットワーク -
LSTM(Long short-term memory) 短期記憶のモデル的な。RNNの拡張・RNNの一種とされる
→今回使うのはこっち
torch.nn.LSTM
Keras, Chainer, TensorFlow PyTorchと同じように機械学習のためのライブラリ
(取り組む前はTensorFlowしか知りませんでした)
-
autograd 自動微分 torch.autograd
backward 後ろ向き計算(後ろめたくない) loss.backward()
Adam(Adaptive moment estimation) 勾配法の一種
Eveもいるようです
torch.optim

こういったもろもろを用意してくれるのがPytorchというライブラリなんですね。

機械学習

今回はゼロベースではなく、参考にしつつ進めています。
https://github.com/komi1230/Predict-Stock-Price

環境自体は、前回構築したAnaconda3でやっていきます。

variable()問題

PyTorch0.4でtorch.autograd.variable()torch.Tensorとマージされました。
それにより、listのラッピングを返却しなくなりました。以下のように修正する必要があります。

get_data.py
import quandl
import numpy as np
import torch
from torch.autograd import Variable


class Get_data:
    def __init__(self, n_prev, data_code):
        self.n_prev = n_prev
        # 今回は、検証用に短い期間を設定して実行します。
        self.data = quandl.get(data_code, start_date="2017/07/01", end_date="2018/07/25", returns="numpy")

    def get_data(self, today):
        tmpX = []
        tmpY = []
        print(today)
        for k in range(self.n_prev):
            tmpX.append(self.data[today - self.n_prev + k][1])
        tmpY.append(self.data[today][1])
        retx = Variable(torch.from_numpy(np.array(tmpX))).float().view(1, 1, self.n_prev)
        rety = Variable(torch.from_numpy(np.array(tmpY))).float().view(1, 1, 1)
        return retx, rety

    def get_raw_data(self):
        return self.data

あとは、appleの株価を取得するソースだったので、agent.pyに反映。

agent.py
import torch.nn as nn
import torch.optim as optim
from network import Network
from get_data import Get_data

data_code = "WIKI/AAPL"  # Quandl data code

# 後略
console
% python agent.py

(多分)学習しているとこまで行きました。
間違ってる等ありましたら、やんわりご指摘お願いします!

次回

次回は、学習データの確認を行います。

Creaithメンバー

この記事の著者:モトキ曽宮

その他メンバー:志村上田

参考

29
43
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
29
43