TensorFlow

tensorflowのHParams

tensorflowには機械学習で共通して使われるような高レベルAPIが含まれている.その中の1つにtf.contrib.training.HParamsがある.公式のドキュメントには簡単な使い方が載っているが,jsonファイルへの保存・読み込みを加えたサンプルを作ってみる.

hprams_test.py
import os
import tensorflow as tf
import argparse
import json

def init_hparams():
    """ get default hparams """
    hparams = tf.contrib.training.HParams(
            learning_rate=0.1,
            num_hidden_units=100,
            activations=['relu', 'tanh'],
            test=False
            )
    return hparams

def print_hparams(hparams):
    """ print hparams to stdout """
    values = hparams.values()

    keys = values.keys()
    keys = sorted(keys)

    print("=================================================")
    for key in keys:
        print("{} : {}".format(key, values[key]))
    print("=================================================")

def save_hparams(filename, hparams):
    """ save to json file """
    if not os.path.exists(os.path.dirname(filename)):
        raise ValueError('There is no directory to save {}'.format(filename))

    with open(filename, 'w') as f:
        json.dump(hparams.values(), f, sort_keys=True, separators=(',', ': '), indent=4)

def load_hparams(filename):
    """ load hparams from json file """
    hparams = init_hparams()
    with open(filename, 'r') as f:
        hparams.parse_json(f.read())
    return hparams

# parse user's arguments. ex. python test.py --hparams test=true
parser = argparse.ArgumentParser(description='Train my model.')
parser.add_argument('--hparams', type=str,
                            help='Comma separated list of "name=value" pairs.')
args = parser.parse_args()

# get default params
hparams = init_hparams()

# overwrite by user's arguments
if not args.hparams is None:
    hparams.parse(args.hparams)

print_hparams(hparams)

filename = 'args.json'

# save as json file
save_hparams(filename, hparams)

# load from json file
hparams = load_hparams(filename)

print_hparams(hparams)

実行する場合にはpython hparams_test.py --hparams test=trueなどとすればデフォルトの値に上書きをすることができる.保存されたjsonファイルは以下のようになる:

args.json
{
    "activations": [
        "relu",
        "tanh"
    ],
    "learning_rate": 0.1,
    "num_hidden_units": 100,
    "test": true
}

また素のHParamsでは値の順番が保証されないので,表示や保存をするときにソートをするようにしておくと読みやすくなる.