LoginSignup
0
1

More than 3 years have passed since last update.

python でクラスをJSONに保存する

Last updated at Posted at 2021-04-03

python のクラスをJSONで読み書きする

メンバ変数が非クラスの単純な変数だけで構成されたクラスであれば書き出し・読み込みが可能です
メンバ変数にクラスがあるような場合は素直に pickle を使いましょう

出力処理

vars(self) を json.dumps します

入力処理

self.__dict__ に対してJSONから読み込んだ変数を入力として update を呼び出します

ソース


# numpy を使う時は下記エンコーダを json.dumps に渡す
class NumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)

class Hoge:
    def __init__(self):
        self.a = 0
        self.b = 100
        self.c = 1000
    def DumpJson(self):
        return json.dumps(vars(self), cls=NumpyEncoder)

    def LoadJson(self, jsonStr):
        params = json.loads(jsonStr)
        self.__dict__.update(params)

numpyだったメンバ変数はnumpyへ戻す処理が別途必要となりますので忘れないように注意します

使い方(書き出し)

model = Hoge()
jsonStr = model.DumpJson()
# jsonStr をファイル等に保存する

使い方(読み込み)

# jsonStr をファイル等から読み出す
jsonStr = "~~ JSON文字列 ~~"
model = Hoge()
model.LoadJson(jsonStr)

実際の使用例

seeds_dataset.txt ダウンロード先
ダウンロード後に "\t\t" を "\t" に変換しないと正常動作しません

  • Test1 で処理を実行しメンバ変数を設定します
  • Test1で出力されたJSONテキストをコピーしてクリップボードに保存します
  • Test1をコメントアウト
  • Test2のjsonStr変数にJSON文字列を貼り付けてTest2のコメントを解除して実行

import json
import numpy as np
import sys

class NumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)

class KMeans:
    def __init__(self, n_clusters=0, n_iterations=100, tol=1e-3):
        self.n_clusters = n_clusters
        self.n_iterations = n_iterations
        self.tol = tol
        self.means = None

    def initRandom(self, X):
        idxs = np.arange(len(X))
        np.random.shuffle(idxs)
        idxs = idxs[:self.n_clusters]
        return X[idxs]

    def eStep(self, X, means):
        predict = []
        for ix in X:
            a = means - ix
            a = a ** 2
            a = np.sum(a, axis=1)
            predict.append(np.argmin(a))
        return np.array(predict)

    def mStep(self, X, predict):
        means = []
        for c in range(self.n_clusters):
            idxs = np.where(predict == c)[0]
            ix = X[idxs]
            means.append(np.mean(ix, axis=0))
        return np.array(means)

    def calcMeansDistance(self, a, b):
        v = a - b
        v = v ** 2
        v = np.sum(v) / self.n_clusters
        return np.sqrt(v)

    def Fit(self, X):
        means = self.initRandom(X)
        for i in range(self.n_iterations):
            predict = self.eStep(X, means)
            newMeans = self.mStep(X, predict)
            distance = self.calcMeansDistance(means, newMeans)
            print("{}/{} {}".format(i, self.n_iterations, distance))
            means = newMeans
            if distance <= self.tol:
                break
        self.means = means
        return means

    def Predict(self, X):
        return self.eStep(X, self.means)

    def DumpJson(self):
        return json.dumps(vars(self), cls=NumpyEncoder)

    def LoadJson(self, jsonStr):
        params = json.loads(jsonStr)
        self.__dict__.update(params)

def Test1():
    data = np.loadtxt("./seeds_dataset.txt", delimiter="\t")
    xdata = data[:,0:7]
    ydata = data[:,7]
    model = KMeans(
        n_clusters=3,
        n_iterations=100,
        tol=1e-3
    )
    model.Fit(xdata)
    predict = model.Predict(xdata)
    print(predict)
    print(model.DumpJson())

def Test2():
    data = np.loadtxt("./seeds_dataset.txt", delimiter="\t")
    xdata = data[:,0:7]
    ydata = data[:,7]

    jsonStr = '{"n_clusters": 3, "n_iterations": 100, "tol": 0.001, "means": [[18.721803278688522, 16.297377049180326, 0.8850868852459014, 6.208934426229506, 3.7226721311475406, 3.603590163934426, 6.0660983606557375], [14.64847222222222, 14.460416666666658, 0.8791666666666667, 5.563777777777778, 3.277902777777778, 2.648933333333333, 5.192319444444446], [11.964415584415585, 13.274805194805198, 0.8522000000000004, 5.229285714285714, 2.8729220779220785, 4.759740259740259, 5.088519480519479]]}'
    model = KMeans()
    model.LoadJson(jsonStr)
    predict = model.Predict(xdata)
    print(predict)
    print(model.DumpJson())

def main():
    #Test1()
    Test2()


if __name__ == "__main__":
    main()

# python my_kmeans.py

以上です

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1