LoginSignup
7
2

More than 3 years have passed since last update.

Tensorflowでの学習をWebsocketで遠隔地から監視できるシステムを作る

Last updated at Posted at 2020-12-21

この記事は 福島高専 Advent Calendar 2020 22日目の記事です。

当方初執筆のため、気になる点などあればコメントよろしくお願いします。

はじめに

Tensorflow での学習を研究室のPCで回すことがよくあるのですが、進捗状況の確認のためにいちいち学校へ行くのがとても面倒なので、自宅でも進捗状況を確認できるシステムを作りました。

やりたいこと

  • 遠隔地からの学習状況の確認
  • 損失関数、評価関数のグラフの描画

実現方法

各 epoch ごとの損失関数や評価関数の値などの諸々の結果を「学習機」から送ってもらい、「監視機」に送信することで実現させます。
学習状況をリアルタイムで監視したいため、Websocketを用いて情報のやり取りを行います。

環境

  • Python 3.8.6
  • Node.js v15.3.0

実装

思ったよりコードが長くなってしまったので、掲載分のコードは適当なところで端折ってあります。
Githubのリポジトリ にソースコードがあるので、そちらを参照してください。

サーバー側

Node.jsを利用します。Websocket サーバーを作成するため、wsパッケージをnpmでインストールしてあげます。

npm install ws

Websocketでは、文字列データを受信/送信することができるため、JSON文字列を送り、受信側でデコードすることで簡単に辞書データを送ることができます。
以下のような構造のJSONデータをやり取りすることにより、送られてきたデータによって処理を分岐させます。

{
  type: "xxx", //データの種類
  data: {
     ... // 諸々のデータ
  }
}
server/main.js
const ws = require("ws");

// データの種類を定数で列挙したファイルをrequire
const messageType = require("./messageType").messageType

//サーバオブジェクト作成
const server = new ws.Server({ port: 8765 })

server.on("connection", (ws) => {
    ws.on("message", (data) => {

        console.log(data)

        // 送られてきたデータをパース
        const jData = JSON.parse(data)

        switch (jData.type) {

            // 接続時
            case messageType.sessionStart:
                ...
                break;

            // 学習タスクに関するメッセージ受信時
            case messageType.trainInfo:
                ...
                break;
    });
});

クライアント側

学習機、監視機どちらもPythonで実装していきます。
学習機でのWebsocket通信にwebsockets、監視機でのGUI作成にPyQt5を使います。

pip install websockets PyQt5

学習機側

各エポックの損失関数・評価関数の値をサーバーに送信するためのコールバックを作成していきます。
Tensorflow のモデルで使用するコールバックは、tf.keras.callbacks.Callbackクラスを継承することで自作できます。
今回は各エポック終了時と学習終了時に値を送信したいので、on_epoch_endon_train_endメソッドに送信部分を書いていきます。

client/ws_callBack.py

import tensorflow as tf
import websockets
import json
from .wsConst import messageType
import re


def makeWSData(dataType: str, data: dict) -> dict:
    return {"type": dataType, "data": data}


class wsConnector(tf.keras.callbacks.Callback):
    def __init__(self, URI, loop, name, details, trainerID, result_regex: dict):
        super().__init__()
        self.loop = loop
        self.trainerID = trainerID
        self.URI = URI

        self.taskID = self.getTaskID(name, details)
        self.result_regex = {k: re.compile(p) for k, p in result_regex.items()}

    def getTaskID(self, name, details):
        trainData = details.copy()
        trainData["layer"] = name
        packet = makeWSData(
            dataType=messageType["trainInfo"],
            data={
                "id": self.trainerID,
                "name": name,
                "type": messageType["train"]["start"],
                "data": trainData
            }
        )

        recvPacket = self.loop.run_until_complete(
                         self.send_and_recv(json.dumps(packet))
                     )
        recvDict = json.loads(recvPacket)

        return recvDict["data"]["taskID"]

    def on_epoch_end(self, epoch, logs=None):
        result = {k: {} for k in self.result_regex.keys()}
        for k, v in logs.items():
            for result_class, result_regex in self.result_regex.items():
                if result_regex.search(k) is not None:
                    result[result_class][k] = float(v)

        packet = makeWSData(
            dataType=messageType["trainInfo"],
            data={
                "id": self.trainerID,
                "type": messageType["train"]["update"],
                "data": {
                    "id": self.taskID, 
                    "epoch": epoch + 1,
                    "result": result
                }
            }
        )
        self.loop.run_until_complete(self.send(json.dumps(packet)))

    def on_train_end(self, logs=None):
        if logs is None:
            result = None
        else:
            result = {k: {} for k in self.result_regex.keys()}
            for k, v in logs.items():
                for result_class, result_regex in self.result_regex.items():
                    if result_regex.search(k) is not None:
                        result[result_class][k] = float(v)

        packet = makeWSData(
            dataType=messageType["trainInfo"],
            data={
                "id": self.trainerID,
                "type": messageType["train"]["end"],
                "data": {
                    "id": self.taskID, 
                    "status": messageType["train"]["success"],
                    "result": result
                }
            }
        )
        self.loop.run_until_complete(self.send(json.dumps(packet)))

    async def send(self, message):
        async with websockets.connect(self.URI) as ws:
            await ws.send(message)

    async def send_and_recv(self, message):
        async with websockets.connect(self.URI) as ws:
            await ws.send(message)
            return await ws.recv()

コールバックを作成したら、モデルがコールバックを呼び出してくれるようmodel.fit時に指定してあげます。

train.py

wsCallBack = wsConnector(
    URI="ws://localhost:8765", loop=loop, name=dataName,
    details=trainParams, trainerID=trainerID, 
    result_regex={"loss": ".*loss.*", "accuracy": ".*accuracy.*"}
)

model.fit(
    trainData,
    epochs=epoch,
    callbacks=[
        wsCallback,
    ],
    validation_data=testData,
)

受信した値をコンソールに出力する Websocket テストサーバを作成し、起動した状態で学習タスクを走らせると、きちんと損失関数・評価関数の値が送信されているのが確認できます。

client/test_server.py
import websockets
import json
import asyncio

async def server(ws, path):
    print(json.loads(await ws.recv()))

loop = asyncio.get_event_loop()
loop.run_until_complete(websockets.serve(server, "localhost", 8765))
loop.run_forever()
テストサーバー出力
{'type': 'TRAIN_INFO', 'data': {'id': 'hoge', 'type': 'UPDATE', 'data': {'id': 'Hoge', 'epoch': 1, 'result': {'loss': {'loss': 6.121824748860965, 'val_loss': 5.398087776068485}, 'accuracy': {'sparse_categorical_accuracy_softmax': 0.08984068036079407, 'val_sparse_categorical_accuracy_softmax': 0.08866003900766373}}}}}
{'type': 'TRAIN_INFO', 'data': {'id': 'hoge', 'type': 'UPDATE', 'data': {'id': 'Hoge', 'epoch': 2, 'result': {'loss': {'loss': 5.307668804372631, 'val_loss': 5.248192136937922}, 'accuracy': {'sparse_categorical_accuracy_softmax': 0.09052243083715439, 'val_sparse_categorical_accuracy_softmax': 0.08866003900766373}}}}}
...

監視機側

PyQt5でGUIを作っていきます。こんな感じで実装しました。
image.png

また、タスク一覧に表示されたタスクをダブルクリックすると、損失関数・評価関数のグラフが表示されるようにしました。
image.png

あとがき

12月に入るまで書く内容が決まらなかったり、「リモートデスクトップ使えば全部解決するんじゃ…」などと考えてモチベーションがだだ下がりしていたのですが、なんとか記事投稿まで持ってこれたので安心しています。

突貫作業だったこともあり、記事の焦点がブレブレなので次回執筆する際にはしっかり考えて書こうと思います。
いろいろ端折ってしまったため、分かりづらいなと感じた箇所は適宜修正・加筆していきます。

参考文献

Pythonの非同期通信(asyncioモジュール)入門を書きました

Githubリポジトリ

7
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
2