4
8

More than 1 year has passed since last update.

Python機械学習/pickleファイルを使ったオブジェクトの保存

Last updated at Posted at 2021-10-10

Introduction

Pythonで機械学習をする際のデータやモデルの保存にはpickleファイルを使用するのが非常に便利なので,その使い方についてメモを残しておく.

データの準備

今回はScikit-learnのガンデータを使用.

from sklearn.datasets import load_breast_cancer
import pandas as pd

cancer = load_breast_cancer()
data_feature = pd.DataFrame(cancer.data, columns=cancer.feature_names)
data_target = pd.DataFrame(cancer.target)

csvファイルでデータ管理

まずは一般的なCSVファイルでのデータ読み書きについて記す.

csvファイルで保存

data_feature.to_csv('data_feature.csv')
data_target.to_csv('data_target.csv')

csvファイルの読込

feature_csv = pd.read_csv('data_feature.csv', index_col=0)
target_csv = pd.read_csv('data_target.csv', index_col=0)

pickleファイルでデータ管理

では本題のpickleファイルでのデータ読み書きについて記す.

pickleファイルで保存

data_feature.to_pickle('data_feature.pkl')
data_target.to_pickle('data_target.pkl')

pickleファイルの読込

feature_pkl = pd.read_pickle('data_feature.pkl')
target_pkl = pd.read_pickle('data_target.pkl')

pickleファイルで保存することで,ヘッダーやインデックスの有無など気にすることなく,保存時のデータをそのまま読み込めるので非常に便利.

機械学習モデルをpickleファイルで管理

生成した機械学習モデルもpickleで保存できるので,その方法について記す.

とりあえずロジスティック回帰でモデル生成.

from sklearn.linear_model import LogisticRegression

model = LogisticRegression()
model.fit(feature_pkl, target_pkl)

print(model.score(feature_pkl, target_pkl))
Out:
0.945518453427065

モデルの保存

import pickle
pickle.dump(model, open('model.pkl', 'wb'))

モデルの読込

model_pkl = pickle.load(open('model.pkl', 'rb'))
print(model_pkl.score(feature_pkl, target_pkl))
Out:
0.945518453427065

Conclusion

今回使用した関数.

  • DF.to_pickle():データフレームをpickleファイルで保存
  • DF.read_pickle():pickleファイルで保存されたデータフレームの読込
  • pickle.dump():機械学習モデルをpickleファイルで保存
  • pickle.load():pickleファイルで保存された機械学習モデルの読込

pythonによる機械学習が捗ると思うので,pickleファイルを活用されたし.

Code

# データの準備
from sklearn.datasets import load_breast_cancer
import pandas as pd

cancer = load_breast_cancer()
data_feature = pd.DataFrame(cancer.data, columns=cancer.feature_names)
data_target = pd.DataFrame(cancer.target)


# csvファイルでデータ管理
## データの保存
data_feature.to_csv('data_feature.csv')
data_target.to_csv('data_target.csv')
## データの読込
feature_csv = pd.read_csv('data_feature.csv', index_col=0)
target_csv = pd.read_csv('data_target.csv', index_col=0)


# pickleファイルでデータ管理
## データの保存
data_feature.to_pickle('data_feature.pkl')
data_target.to_pickle('data_target.pkl')
## データの読込
feature_pkl = pd.read_pickle('data_feature.pkl')
target_pkl = pd.read_pickle('data_target.pkl')


# 機械学習モデルをpickleファイルで管理
## モデル生成
from sklearn.linear_model import LogisticRegression

model = LogisticRegression()
model.fit(feature_pkl, target_pkl)

print(model.score(feature_pkl, target_pkl))
## データ保存
import pickle
pickle.dump(model, open('model.pkl', 'wb'))
## データの読込
model_pkl = pickle.load(open('model.pkl', 'rb'))

print(model_pkl.score(feature_pkl, target_pkl))
4
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
8