LoginSignup
5
6

googleのtimesfmのモデルで日経平均を予測してみた

Posted at

経緯説明

最近(2024/05/17時点) googleやopenAIなど各社が色々なAIを発表して、色々面白いもの発表しているところほとんどLLM、個人的にLLMはそこまでやる気がないけど、別の面白いモデルがあればいいなぁと思ってhugging face に潜入捜査したら、なんとmodels のランキング1がgoogleのtimesfmのモデルだ!metaの有名なLLM(Llama3)を超えている!気になって見たら時系列のモデルだそうです。

image.png

公式資料
googleのresearch リンク:https://research.google/blog/a-decoder-only-foundation-model-for-time-series-forecasting/
github リンク: https://github.com/google-research/timesfm?tab=readme-ov-file
hugging face リンク:https://huggingface.co/google/timesfm-1.0-200m

手順説明

それで株価も時系列と関係あると思い出して、ほぉ~これはもしかして前にやりたかった株価予測できるでは!GPTさんに聞いてみたら、できるとの返事が来て、昔自分が途中まで書いてたコードを取り出して、公式のサンプルコードを合成して、GPT-4oの協力によってなんかうまくいってるように見えたので、投稿しました。

実際精度とか本当に投資して報酬出来るかどうかはわからないです。
もしこのコード参考して改造などやって投資失敗したら自己責任です。

まずログインして、アクセス権をもらいます。すぐ使える。
image.png

そして、右上のアカウント設定でtokenを新規メモ帳などでコピペしてもいいです。後で使います。
今回は珍しくgoogle colabを利用しています。python version <3.10 が条件だそうです。試してみたがエラーでPythonをインストールするのがめんどくさいからgoogle colabで実行しています。

以下google colab で実行

! git clone https://github.com/google-research/timesfm.git
%cd timesfm
!pip install -e .
!pip install utilsforecast
!pip install transformers accelerate bitsandbytes

huggingface Tokenの登録
下の****************** はコピーしたtokenです

!huggingface-cli login --token ******************

main code

公式のgithub 説明してたが少しだけ言います
TimesFm()の中のbackend は gpu cpu tpu の三つの選択ができます。

そして、yfinanceで予測したい株価を検索してcodelist に入れ替えればいいです。
yfinanceのリンク: https://finance.yahoo.com/

import datetime
import yfinance as yf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from timesfm import TimesFm

# 日経平均株価のデータ取得
start = datetime.date(2022, 1, 1)
end = datetime.date.today()
codelist = ["^N225"]
data2 = yf.download(codelist, start=start, end=end)

# データの前処理
data2 = data2['Adj Close'].dropna()  # Adjusted Close価格を使用し、欠損値を除去
if data2.empty:
    raise ValueError("データが空です。期間を変更して再度試してください。")

context_len = 512  # コンテキスト長の設定
horizon_len = 128  # 予測する期間の長さの設定

if len(data2) < context_len:
    raise ValueError(f"データの長さがコンテキスト長({context_len})より短いです。")

context_data = data2[-context_len:]  # 最新の512日分のデータをコンテキストとして使用

# TimesFMモデルの初期化と読み込み
tfm = TimesFm(
    context_len=context_len,
    horizon_len=horizon_len,
    input_patch_len=32,
    output_patch_len=128,
    num_layers=20,
    model_dims=1280,
    backend='gpu',
)
tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m")

# データの準備
forecast_input = [context_data.values]
frequency_input = [0]  # データの頻度を設定(0は高頻度のデータ)

# 予測の実行
point_forecast, experimental_quantile_forecast = tfm.forecast(
    forecast_input,
    freq=frequency_input,
)

# 予測結果の表示
forecast_dates = pd.date_range(start=data2.index[-1] + pd.Timedelta(days=1), periods=horizon_len, freq='B')
forecast_series = pd.Series(point_forecast[0], index=forecast_dates)

plt.figure(figsize=(14, 7))
plt.plot(data2.index, data2.values, label="Actual Prices")
plt.plot(forecast_series.index, forecast_series.values, label="Forecasted Prices")
plt.xlabel("Date")
plt.ylabel("Price")
plt.legend()
plt.show()

日経平均株価予測結果の画像
image.png

5
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
6