3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Hyperbolic GCN のライブラリー解説(2)

Posted at

はじめに

本記事はHyperbolic GCN のライブラリー解説(1)の続きです。もしご覧になっていない方はよろしければ御覧ください。

この記事は

  • Pytorch などの機械学習ライブラリー未経験の人を想定して
  • 実データの解析までできるようになることを目標に
  • 実装にあたって自分が詰まった経験を踏まえ

複数回に分けて投稿します。もし内容に誤りがありましたら教えていただけますと幸いです。

また、今後は

  • ライブラリーの構成(前回)
  • コードの詳細(今回含め数回)
  • データの準備
  • hyperbolicGCN の理論的背景

の順に更新します。応用をしたい方にとっては分野に共通するライブラリー構成を初めにざっと掴み、コードを読みながら理論を知る方が意義が大きいのではないか、と狙った記事構成になっています。

今回扱うライブラリーは、グラフデータを扱う深層学習 (Graph Neural Network, GNN) のライブラリーの一つ、Hyperbolic Graph Convolutional Network in Pytorch です。GNNの種類の一つ、GCN に共通する基本的な構造から始めて、最後の回で理論の詳細に触れようと思います。

今回ご紹介するライブラリーは以下にあります。
https://github.com/HazyResearch/hgcn

今回の記事はプログラムを実行するときの main 関数が含まれる train.py を順に解説していきます。

簡単な解説→コード(コード内のコメントに実行内容を記載)

という流れで進めていきます。

処理の流れ

まず、このファイルの処理の流れは以下の通りです。

スクリーンショット 2020-03-15 21.26.30.png

では、詳しくコードを読み込んでいきます。

main関数

コマンドラインから各種設定を読み込む

 
いきなりコードの一番下の部分から始まりますが、main 関数なのでご容赦ください。
main関数ではコマンドラインと config.py に入力された各種変数を取得し、train 関数に渡します。
args.変数名 でその引数の値にアクセスできます。参考
たとえば print(args.seed) とすると、乱数生成器を指定するシード値を出力できます。
各変数のデフォルトの値を見たいとき/設定したいときは config.py を見ます/書き換えます。

train.py
if __name__ == '__main__'
    # 以下で各種パラメーターを取得して args に渡す
    args = parser.parse_args()
    train(args)

たとえば活性化関数もコマンドラインから指定できます。
詳細は以下の README 詳しく書いてあり、すぐに実行することができます。
https://github.com/HazyResearch/hgcn

train 関数

結果の再現性を得るために乱数生成器を指定する

学習の精度は、以下の方法で評価されます。

  1. データを訓練データとテストデータにランダムに分けます。
  2. 訓練データを用いて学習を行い、テストデータで学習の精度を評価します。

データを訓練データとテストデータに分割するときの乱数生成器を一定にすることでその他の条件(使用するマシンなど)を揃えたときの再現性を得やすくなります。(ただし、Pytorchのバージョン違いやデバイスの違いなどにより完璧に再現されることは保証されていません。)

Pytorchのtorch.manual_seed(シード値) は乱数生成器 (RGN) を一定にすることを保証します。
参考:pytorch reproducibility
本ライブラリー作成者はシード値を ’1234’ にして計算しています。

train.py
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if int(args.double_precision):
        torch.set_default_dtype(torch.float64)
    if int(args.cuda) >= 0:
        torch.cuda.manual_seed(args.seed)

GPU or CPU を指定する

args.cuda の値が 0 以上であれば GPU を指定します。それ以外は CPU を指定します。デフォルトは CPU になっていました。
2行目はちょっと内容がずれますが、過学習を防ぐために、早期停止する設定をします。
早期停止 (early-stopping) についてはいかが参考になります。
http://torch.classcat.com/2016/03/27/deeplearning-tutorial-gettingstarted/

train.py
    # GPU or CPU 指定
    args.device = 'cuda:' + str(args.cuda) if int(args.cuda) >= 0 else 'cpu'
    # 早期停止を行う条件を指定
    args.patience = args.epochs if not args.patience else  int(args.patience)

ログ出力の設定をする

logの取得参考

上の参考記事の中にログの出力にもいくつかのレベルがあると書いてありますが、ここでは ”INFO” を指定しています。

コマンドラインから --save=1 と指定すると指定した場合に結果の保存先のパスを生成します。各種設定や結果が記録されたログをこの保存先に書き込みます。結果のファイルを後で処理してグラフなどにまとめるので、自分にとって処理しやすいフォーマットやパスに変えるのも良いと思います。
重複になるので、これ以降ログ出力に関するコードは適宜省きます。

train.py
    # ログを取るレベルを指定
    logging.getLogger().setLevel(logging.INFO)
    # 保存する場合パスを指定
    if args.save:
        if not args.save_dir:
            dt = datetime.datetime.now()
            # フォーマットを指定する文法だそうです。簡単で使いやすそう。
            date = f"{dt.year}_{dt.month}_{dt.day}"
            models_dir = os.path.join(os.environ['LOG_DIR'], args.task, date)
            save_dir = get_dir_name(models_dir)
        else:
            save_dir = args.save_dir
        logging.basicConfig(level=logging.INFO,
                            handlers=[
                                logging.FileHandler(os.path.join(save_dir, 'log.txt')),
                                logging.StreamHandler()
                            ])
    # ログを出力するときは以下のように書く
    logging.info(f'Using: {args.device}')
    logging.info("Using seed {}.".format(args.seed))

データを読み込む

utils フォルダの data_utils 関数を用いて data フォルダにあるデータを読み込みます。
※データ型などの utilis のコードの解説は後ほどアップします。
ノードの数や特徴の数もここで読み込みます。

train.py
    # Load data
    print(os.environ)
    # set_env.sh がうまく行かなかったときはここで環境変数を指定しても(対症療法だけど)
    data = load_data(args, os.path.join(os.environ['DATAPATH'], args.dataset))
    args.n_nodes, args.feat_dim = data['features'].shape

タスクを指定する

タスクの種類は2つあります。
ノードの分類 'nc' と結合の予測 'lp' です。

ノード分類のタスクを行う場合は、モデルは'NCModel'を指定し、分類するクラスの数はデータのラベルの数になります。
結合予測のタスクを行う場合は、モデルは'LPModel'を指定し、分類するクラスの数はデータのラベルの数になります。
※nc/lp 以外のタスクを指定することができるようになっているみたい。その時は RECMODEL を指定するとあるが、なんのことかまだわかっていないです。あとで見てみます。

train.py
    if args.task == 'nc':
        Model = NCModel
        args.n_classes = int(data['labels'].max() + 1)
    else:
        args.nb_false_edges = len(data['train_edges_false'])
        args.nb_edges = len(data['train_edges'])
        if args.task == 'lp':
            Model = LPModel
        else:
            Model = RECModel
            # No validation for reconstruction task
            args.eval_freq = args.epochs + 1

model と optimizer を指定する

機械学習の文脈における optimizer とは、ロス関数の値をより小さくするための手法です。
参考

lr_scheduler は学習率を適宜変化させる設定を行います。以下が参考になります。
参考
lr_reduce_freq でこの幅のとり方を指定しているようです。

最後にGPUで計算する場合はそれに合うようにデータを整形します。

train.py
    # 学習率を変化させる頻度を指定しない場合はエポック数から決定
    if not args.lr_reduce_freq:
        args.lr_reduce_freq = args.epochs
    # Model and optimizer
    model = Model(args)
    optimizer = getattr(optimizers, args.optimizer)(params=model.parameters(), lr=args.lr,
                                                    weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=int(args.lr_reduce_freq),
        gamma=float(args.gamma)
    )
    tot_params = sum([np.prod(p.size()) for p in model.parameters()])
    # GPU を利用する場合
    if args.cuda is not None and int(args.cuda) >= 0 :
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
        model = model.to(args.device)
        for x, val in data.items():
            if torch.is_tensor(data[x]):
                data[x] = data[x].to(args.device)

繰り返し学習を行う

各訓練データを学習させる繰り返し回数を epoch 数といいます。
参考
毎回の学習率を各繰り返しごとに更新して、学習を行い結果を出力します。

train.py
    # Train model
    # 計算時間の測定を開始
    t_total = time.time()
    counter = 0
    # 結果を格納する変数は初期化
    best_val_metrics = model.init_metric_dict()
    best_test_metrics = None
    best_emb = None
    # 繰り返し学習する
    for epoch in range(args.epochs):
        t = time.time()
        # 訓練モードを開始
        model.train()
        optimizer.zero_grad()
        # ユークリッド空間上の点を双曲空間に埋め込み各層の計算を進め、ロス関数の値を計算する
        embeddings = model.encode(data['features'], data['adj_train_norm'])
        train_metrics = model.compute_metrics(embeddings, data, 'train')
        # ロス関数の値に基づき重みを調整
        train_metrics['loss'].backward()
        if args.grad_clip is not None:
            max_norm = float(args.grad_clip)
            all_params = list(model.parameters())
            for param in all_params:
                torch.nn.utils.clip_grad_norm_(param, max_norm)
        optimizer.step()
        # 学習率を変更
        lr_scheduler.step()
        # 適当な間隔で結果を出力     
        if (epoch + 1) % args.log_freq == 0:
            logging.info(" ".join(['Epoch: {:04d}'.format(epoch + 1),
                                   'lr: {}'.format(lr_scheduler.get_lr()[0]),
                                   format_metrics(train_metrics, 'train'),
                                   'time: {:.4f}s'.format(time.time() - t)
                                   ]))
        # 適当な間隔でテストを行う     
        if (epoch + 1) % args.eval_freq == 0:
            # 訓練モードからテストモードに切り替える
            model.eval()
            embeddings = model.encode(data['features'], data['adj_train_norm'])
            # テストデータにモデルを当てはめて学習の精度評価を行う
            val_metrics = model.compute_metrics(embeddings, data, 'val')
            if (epoch + 1) % args.log_freq == 0:
                logging.info(" ".join(['Epoch: {:04d}'.format(epoch + 1), format_metrics(val_metrics, 'val')]))
            if model.has_improved(best_val_metrics, val_metrics):
                # 結果が改善されたら結果を出力して保存する
                best_test_metrics = model.compute_metrics(embeddings, data, 'test')
                best_emb = embeddings.cpu()
                if args.save:
                    np.save(os.path.join(save_dir, 'embeddings.npy'), best_emb.detach().numpy())
                best_val_metrics = val_metrics
                counter = 0
            else:
                counter += 1
                # 過学習が起こることを避けるために早期停止の判断をする
                if counter == args.patience and epoch > args.min_epochs:
                    logging.info("Early stopping")
                    break
    # この時点でテストが行われていない場合ここでテストを行う
    # 説明の流れを考えて "Optimization Finished!" の行よりちょっと上に持ってきました。
    if not best_test_metrics:
        model.eval()
        best_emb = model.encode(data['features'], data['adj_train_norm'])
        best_test_metrics = model.compute_metrics(best_emb, data, 'test')


結果を出力する

結果を出力します。
繰り返しになりますがこのあたりの設定を変えれば、ご自身でグラフにまとめやすくしたりできます。

train.py
    logging.info("Optimization Finished!")
    logging.info("Total time elapsed: {:.4f}s".format(time.time() - t_total))
    logging.info(" ".join(["Val set results:", format_metrics(best_val_metrics, 'val')]))
    logging.info(" ".join(["Test set results:", format_metrics(best_test_metrics, 'test')]))
    # 結果を保存したいときはここを見る
    if args.save:
        np.save(os.path.join(save_dir, 'embeddings.npy'), best_emb.cpu().detach().numpy())
        if hasattr(model.encoder, 'att_adj'):
            filename = os.path.join(save_dir, args.dataset + '_att_adj.p')
            pickle.dump(model.encoder.att_adj.cpu().to_dense(), open(filename, 'wb'))
            print('Dumped attention adj: ' + filename)

        json.dump(vars(args), open(os.path.join(save_dir, 'config.json'), 'w'))
        torch.save(model.state_dict(), os.path.join(save_dir, 'model.pth'))
        logging.info(f"Saved model in {save_dir}")

終わりに

さて今回は、このライブラリーの main 関数を含んでいる train.py の内容を解説しました。train.py という名前ですがテストもしてくれていますね。実行するコマンドもデータも揃っているので、動かすことができます。このライブラリーは説明がとても丁寧で本当に助かります…!親切!ありがたい…。

次回は、データの中身がどうなっているか詳しく見ていこうと思います。引き続きよろしくお願いいたします!

3
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?