LoginSignup
4
5

More than 1 year has passed since last update.

pythonで世界経済を予想する

Last updated at Posted at 2017-12-10

はじめに

SVMを使って、世界合計のGDPを年度から予想します。
今回の予想は2つのことを仮定をしています。

・世界経済は成長し続けるものとする。

・今後予想できないような、急激な経済成長は考えないものとする。

データをゲットする

IMFが世界経済の推移をcsvで配布していますのでそれをダウンロードします。
データ

①最新年度を選んで、By Country Groups (aggregated data) and commodity pricesをクリック。

②worldのみチェックを入れる

③Gross domestic product, current prices U.S. dollarsにチェックを入れる

④下のほうの所からエクセルファイルをダウンロード。

⑤データを加工する。(そこそこ、大変でした)

出来れば、加工後を載せたいのですが、再配布していいのかちょっと微妙なのでIMFのサイトからダウンロードしてください。

注意

上記のファイルのGDPの単位は10億米ドルです。
よって、上記のファイルをそのまま使うとGDPの単位が10億米ドルとなります。

ソースコードを書く

さて、pythonとscikit-learnでソースコードを書きたいと思います。

main.py
import sys
from sklearn.svm import LinearSVC
from sklearn import svm
import numpy as np
import copy
import matplotlib.pyplot as plt
from sklearn.model_selection import GridSearchCV
from sklearn.kernel_ridge import KernelRidge
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.linear_model import Ridge


args = sys.argv

def regression_main():
     # model = svm.SVR(kernel='rbf', C=1e3, gamma=0.1)
      data=[]
      with open(args[1],"r",encoding="utf-8") as f:
           data=[e.replace("\n", "") for e in f.readlines()]
           data=[[float(e) for e in e.split(",")] for e in copy.deepcopy(data) ]
      
      in_data=[e[:-1] for e in data]
      label_data=[e[-1] for e in data]
      
      print(in_data,label_data)
      
      tuned_parameters = [
        {'kernel': ['rbf'], 'gamma': [10**i for i in range(-4,5)], 'C': [10**i for i in range(-3, 8)]}#,
     #   {'kernel': ['linear'], 'C': [1, 10, 100, 1000]}
      ]

      model = GridSearchCV(svm.SVR(), tuned_parameters, cv=5, scoring="mean_squared_error")

      model.fit(in_data,label_data)

      
      plt.scatter([e[0] for e in in_data], [e for e in label_data])
      plt.plot([e[0] for e in in_data], model.predict(in_data))
      plt.show()
      
      while True:
          user_input=input("end?>")
          if user_input=="end":break
          user_in_data=[]
          for i in range(len(in_data[0])):
               t=float(input(">"))
               user_in_data.append(t)
          
          print(model.predict([user_in_data])[0])
      


def main():
     print("1:label\n2:回帰")
     n=int(input(">"))     
     if n==1:
         #label_main()
     if n==2:
         regression_main()

main()

種も仕掛けもない普通のソースコードです。
ただ、パラメーターを色々微調整するのが面倒くさいのでGridSearchCVで自動的に最適化しています。

実行する時は

python main.py 学習ファイル

です

学習データのフォーマット

データの加工の行い方を載せてないので代わりに学習データのフォーマットを載せておきます。

引数1,引数2,引数3...,出力
引数1,引数2,引数3...,出力
...

こんな感じですので今回の場合は

年度,世界合計GDP
年度,世界合計GDP
...

です。

実行する

実行してみたいと思います。

キャプチャ.PNG

青い丸が実際の世界のGDPで青い線がSVMの予想値です。
ほぼ、一次関数ですね...
わざわざ、SVMを使った意味がない気もします。
とにかく、このモデルを使って、予想してみたいと思います。

予想してみる。

では早速、2022と入力して2022年を予想してみたいと思います。

出力は929482.95366億米国ドルでした。
IMFの予想は1032009.94億米国ドルでしたので、
一応予想が出来ました。

まとめ

今回のデータは一次関数が当てはまったので余りSVMの恩恵が得られませんでした。

4
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
5