LoginSignup
0
0

Deep Java Library (DJL) と GridDB を用いた時系列データの予測

Posted at

はじめに

時系列データはどこにでもあります。株価や天候パターンから売上高やセンサーデータまで、私たちの生活の様々な場面で重要な役割を果たしています。過去の時系列データに基づいて将来の値を予測できることは、情報に基づいた意思決定を行う上で非常に貴重です。この記事では、Deep Java Library(DJL)とGridDBを使って時系列データを予測する方法を探ります。

時系列データの特徴

時系列データとは、時系列に並んだデータのことで、各データポイントは特定のタイムスタンプに関連付けられています。このデータ形式は金融、ヘルスケア、IoTなど様々な領域で普及しています。効果的な時系列予測を行うには、このような時間的パターンを捉え、理解できるツールやテクニックが必要です。

Deep Java Library(DJL)の紹介

DJLはオープンソースのディープラーニング・ライブラリで、Java開発者にディープラーニングのパワーをもたらすように設計されています。訓練済みモデル、カスタムモデルを訓練するためのツール、TensorFlow、PyTorch、MXNetのような様々なディープラーニングフレームワークとのシームレスな統合を提供します。

時系列予測のためのディープラーニング

ディープラーニングは、複雑な時系列予測問題の解決に目覚ましい成果を示しています。DeepAR(Deep Autoregressive)のようなモデルは、複雑な時間依存関係を捉え、正確な予測を生成することができます。DJLは、このようなモデルを時系列予測タスクに簡単に実装、展開する方法を提供します。

時系列予測にDJLを使用する

時系列予測のためにDJLを使い始めるには、以下のライブラリをプロジェクトに追加する必要があります。プロジェクトが Maven に基づいていると仮定すると、POM ファイルの依存セクションにこれらを追加する必要があります。

        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.23.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.timeseries</groupId>
            <artifactId>timeseries</artifactId>
            <version>0.23.0</version>
        </dependency> 
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-model-zoo</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <!-- ONNXRuntime -->
        <dependency>
            <groupId>ai.djl.onnxruntime</groupId>
            <artifactId>onnxruntime-engine</artifactId>
            <version>${djl.version}</version>
        </dependency>

次に、環境をセットアップし、いくつかの重要なコンポーネントを理解する必要があります。以下のコード・スニペットを詳しく見てみましょう。

// Import necessary libraries
import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.basicdataset.tabular.utils.Feature;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.*;
import ai.djl.timeseries.Forecast;
import ai.djl.timeseries.TimeSeriesData;
import ai.djl.timeseries.dataset.FieldName;
import ai.djl.timeseries.dataset.TimeFeaturizers;
import ai.djl.timeseries.distribution.DistributionLoss;
import ai.djl.timeseries.distribution.output.DistributionOutput;
// ... other necessary imports ...

public class MonthlyProductionForecast {
    // Constants and configurations
    
    final static String FREQ = "W";
    final static int PREDICTION_LENGTH = 4;
    final static LocalDateTime START_TIME = LocalDateTime.parse("2011-01-29T00:00");
    final static String MODEL_OUTPUT_DIR = "outputs";
     public static void main(String[] args) throws Exception {
        Logger.getAnonymousLogger().info("Starting...");        
        startTraining();
        final Map result = predict();
        for (Map.Entry entry : result.entrySet()) {
            Logger.getAnonymousLogger().info(String.format("metric: %s:\t%.2f", entry.getKey(), entry.getValue()));
        }
    }
}

サンプルコードの完全なプロジェクトはこちらからアクセスできます: GitHub リポジトリ

上記のコードは、時系列予測アプリケーションのエントリーポイントです。このコードでは、構成を設定し、データをロードし、DJLを使用してDeepARモデルを学習します。これがどのように機能するのかを分解してみましょう。

  • 必要なDJLライブラリをインポートし、時系列データの頻度、予測長、開始時間などの定数を定義します。

  • main メソッドは学習プロセスを開始し、GridDBデータベースに接続して時系列データをシードします。GridDBは分散型で拡張性の高いNoSQLデータベースであり、時系列データを効率的に格納することができます。

  • 予測、トレーニングセットアップ、データロードのための様々なメソッドを定義しています。

GridDB について

GridDBは、大量の時系列データを保存・管理するために設計された強力なデータベースシステムです。その高速なデータ取り込みとクエリ機能により、時系列予測アプリケーションに最適な選択肢となっています。

GridDB への時系列データの格納

まず、プロジェクトでGridDBを使用できるようにするために、mavenの依存関係を追加する必要があります。

        <dependency>
            <groupId>com.github.griddb</groupId>
            <artifactId>gridstore-jdbc</artifactId>
            <version>5.3.0</version>
        </dependency>
        <dependency>
            <groupId>com.github</groupId>
            <artifactId>gridstore</artifactId>
            <version>5.3.0</version>
        </dependency>  

次に、やりたいことを実現するために、データベースに時系列データを投入する必要があります。GridDBDatasetクラスの seedDatabase メソッドで、GridDBデータベースに時系列データを投入します。データは2つのcsvファイルから読み込まれ、2つの別々のコンテナに格納されます。以下はそのコードです。

  private static void seedDatabase() throws Exception {
          URL trainingData = Forecaster.class.getClassLoader().getResource("data/weekly_sales_train_validation.csv");
            URL validationData = Forecaster.class.getClassLoader().getResource("data/weekly_sales_train_evaluation.csv");
            String[] nextRecord;
            try ( GridStore store = GridDBDataset.connectToGridDB();  CSVReader csvReader = new CSVReader(new InputStreamReader(trainingData.openStream(), StandardCharsets.UTF_8));  CSVReader csvValidationReader = new CSVReader(new InputStreamReader(validationData.openStream(), StandardCharsets.UTF_8))) {
                store.dropContainer(TRAINING_COLLECTION_NAME);
                store.dropContainer(VALIDATION_COLLECTION_NAME);

                List columnInfoList = new ArrayList<>();

                nextRecord = csvReader.readNext();
                for (int i = 0; i < nextRecord.length; i++) {
                    ColumnInfo columnInfo = new ColumnInfo(nextRecord[i], GSType.STRING);
                    columnInfoList.add(columnInfo);
                }

                ContainerInfo containerInfo = new ContainerInfo();
                containerInfo.setColumnInfoList(columnInfoList);
                containerInfo.setName(TRAINING_COLLECTION_NAME);
                containerInfo.setType(ContainerType.COLLECTION);

                Container container = store.putContainer(TRAINING_COLLECTION_NAME, containerInfo, false);

                while ((nextRecord = csvReader.readNext()) != null) {
                    Row row = container.createRow();
                    for (int i = 0; i < nextRecord.length; i++) {
                        row.setString(i, nextRecord[i]);
                    }
                    container.put(row);
                }

                nextRecord = csvValidationReader.readNext();
                columnInfoList.clear();
                for (int i = 0; i < nextRecord.length; i++) {
                    ColumnInfo columnInfo = new ColumnInfo(nextRecord[i], GSType.STRING);
                    columnInfoList.add(columnInfo);
                }

                containerInfo = new ContainerInfo();
                containerInfo.setName(VALIDATION_COLLECTION_NAME);
                containerInfo.setColumnInfoList(columnInfoList);
                containerInfo.setType(ContainerType.COLLECTION);

                container = store.putContainer(VALIDATION_COLLECTION_NAME, containerInfo, false);
                while ((nextRecord = csvValidationReader.readNext()) != null) {
                    Row row = container.createRow();
                    for (int i = 0; i < nextRecord.length; i++) {
                        String cell = nextRecord[i];
                        row.setString(i, cell);
                    }
                    container.put(row);
                }
            }
    }

DJL と GridDB の統合

DJLとGridDBはシームレスに連携します。GridDBに接続して時系列データにアクセスし、DJLを使用して予測モデルを構築、学習、デプロイします。GridDBDataset クラスは、GridDBデータセットとやり取りするために必要な機能を提供します。

DJLのTimeSeriesDatasetのカスタム実装を作成する必要があることがわかりました。これは、DJLとカスタムデータリポジトリをシームレスに統合するための最も独創的な方法の1つです。その実装がこちらです。

...
public class GridDBDataset extends M5Forecast {

    ...

    public static GridStore connectToGridDB() throws GSException {
        Properties props = new Properties();
        props.setProperty("notificationMember", "127.0.0.1:10001");
        props.setProperty("clusterName", "defaultCluster");
        props.setProperty("user", "admin");
        props.setProperty("password", "admin");
        return GridStoreFactory.getInstance().getGridStore(props);
    }

    public static class GridDBBuilder extends M5Forecast.Builder {
    ...

        private File fetchDBDataAndSaveCSV(GridStore store) throws GSException, FileNotFoundException {
           File csvOutputFile = new File(this.getContainerName()+ ".csv");
            try ( GridStore store2 = store) {
                Container container = store2.getContainer(this.getContainerName());

                Query query = container.query("Select *");
                RowSet rowSet = query.fetch();

                int columnCount = rowSet.getSchema().getColumnCount();

                List csv = new LinkedList<>();
                StringBuilder builder = new StringBuilder();

                //Loan column headers
                ContainerInfo cInfo = rowSet.getSchema();
                for (int i = 0; i < cInfo.getColumnCount(); i++) {
                    ColumnInfo columnInfo = rowSet.getSchema().getColumnInfo(i);
                    builder.append(columnInfo.getName());
                    appendComma(builder, i, cInfo.getColumnCount());
                }
                csv.add(builder.toString());

                //Load each row
                while (rowSet.hasNext()) {
                    Row row = rowSet.next();
                    builder = new StringBuilder();
                    for (int i = 0; i < columnCount; i++) {
                        String val = row.getString(i);
                        builder.append(val);
                        appendComma(builder, i, columnCount);
                    }
                    csv.add(builder.toString());
                }
                try ( PrintWriter pw = new PrintWriter(csvOutputFile)) {
                    csv.stream()
                            .forEach(pw::println);
                }
            }
            return csvOutputFile;
        }

        public GridDBBuilder initData() throws GSException, FileNotFoundException {
            this.csvFile = fetchDBDataAndSaveCSV(this.store);
            return this;
        }

        @Override
        public GridDBDataset build() {
            GridDBDataset gridDBDataset = null;
            try {
                gridDBDataset = new GridDBDataset(this);
            } catch (GSException | FileNotFoundException ex) {
                Logger.getLogger(GridDBDataset.class.getName()).log(Level.SEVERE, null, ex);
            }
            return gridDBDataset;
        }

    }
}

高度な時系列予測モデルの構築

我々の予測能力の中核はDeepARモデルにあります。startTraining メソッドでDeepARモデルを作成、訓練、評価します。DJLの使いやすいAPIにより、モデル・アーキテクチャを定義し、時系列データで訓練することが容易になります。

private static void startTraining() throws IOException, TranslateException, Exception {

        DistributionOutput distributionOutput = new NegativeBinomialOutput();

        Model model = null;
        Trainer trainer = null;
        NDManager manager = null;
        try {
            manager = NDManager.newBaseManager();
            model = Model.newInstance("deepar");
            DeepARNetwork trainingNetwork = getDeepARModel(distributionOutput, true);
            model.setBlock(trainingNetwork);

            List trainingTransformation = trainingNetwork.createTrainingTransformation(manager);

            Dataset trainSet = getDataset(Dataset.Usage.TRAIN, trainingNetwork.getContextLength(), trainingTransformation);

            trainer = model.newTrainer(setupTrainingConfig(distributionOutput));
            trainer.setMetrics(new Metrics());

            int historyLength = trainingNetwork.getHistoryLength();
            Shape[] inputShapes = new Shape[9];
            // (N, num_cardinality)
            inputShapes[0] = new Shape(1, 1);
            // (N, num_real) if use_feat_stat_real else (N, 1)
            inputShapes[1] = new Shape(1, 1);
            // (N, history_length, num_time_feat + num_age_feat)
            inputShapes[2] = new Shape(1, historyLength, TimeFeature.timeFeaturesFromFreqStr(FREQ).size() + 1);
            inputShapes[3] = new Shape(1, historyLength);
            inputShapes[4] = new Shape(1, historyLength);
            inputShapes[5] = new Shape(1, historyLength);
            inputShapes[6] = new Shape(1, 1, TimeFeature.timeFeaturesFromFreqStr(FREQ).size() + 1);
            inputShapes[7] = new Shape(1, 1);
            inputShapes[8] = new Shape(1, 1);
            trainer.initialize(inputShapes);
            int epoch = 10;
            EasyTrain.fit(trainer, epoch, trainSet, null);
        } finally {
            if (trainer != null) {
                trainer.close();
            }
            if (model != null) {
                model.close();
            }
            if (manager != null) {
                manager.close();
            }
        }
    }

それでは startTraining メソッドの各ステップを分解してみましょう。

ステップ 1:モデルの分布出力を定義します。この場合、NegativeBinomialOutputに設定します。分布出力は、モデルが予測を生成する方法を指定します。

ステップ 2: getDeepARModel メソッドを使用して、DeepAR 学習ネットワークを作成します。このネットワークは、DeepAR モデルのアーキテクチャを定義します。重要なのは、これがトレーニングネットワークであることを示すために true を渡すことです。

ステップ 3: データセットのトレーニング変換を定義します。これらの変換は入力データに適用され、トレーニングに備えます。データの正規化、特徴エンジニアリングなどが含まれます。

ステップ 4:getDatasetメソッドを使用してトレーニングデータセットを準備します。このデータセットはDeepARモデルの学習に使用されます。このデータセットには、過去のデータとトレーニングの目標値が含まれます。

ステップ 5: モデルをトレーニングするためのトレーナーを作成し、設定します。setupTrainingConfigメソッドは、損失関数、評価子、トレーニングリスナーを含むトレーニング構成を設定します。

ステップ 6:入力形状でトレーナを初期化します。このステップでは,トレーナがモデルに期待される入力形状を知っていることを確認します.inputShapes 配列には、モデルの様々な入力コンポーネントの形状が格納されます。

ステップ 7: 最後に、EasyTrain.fitメソッドを用いてモデルの学習を開始します。トレーニングエポック数、トレーニングデータセット(trainSet)、その他のオプションパラメータを指定します。トレーナはモデルのパラメータを最適化して、定義された損失関数を最小化し、トレーニングデータでの性能を向上させます。

全体として startTraining メソッドはモデルを設定し、データセットを準備し、トレーナを初期化することで、時系列予測のための DeepAR モデルを設定し、訓練します。このステップの組み合わせにより、過去の時系列データに基づいて正確な予測を行うためのモデルの効果的な学習が保証されます。

予測する

学習後、predict メソッドを使って学習したモデルに基づいて予測を行うことができます。このメソッドはモデルのパフォーマンスを評価するために、RMSSE (Root Mean Squared Scaled Error)、MSE (Mean Squared Error)、分位点損失などの様々なメトリクスを計算します。

結論

この記事では、DJLとGridDBを使用して時系列データを予測する方法を探りました。時系列データ、DJL、GridDB の主要概念を紹介し、時系列予測のための DeepAR モデルの構築と学習に関わるコードの詳細な説明を行いました。ディープラーニングのパワーとGridDBの効率性を組み合わせることで、時系列データから貴重な洞察を引き出し、ビジネスや研究のために情報に基づいた意思決定を行うことができます。DJLの使いやすさと柔軟性は、時系列予測の課題に取り組もうとするあらゆるJava開発者にとって貴重なツールとなっています。

結論として、DJLとGridDBの相乗効果により、時系列データの潜在能力を活用し、様々な領域でより良い意思決定を促進する正確な予測を提供することができます。本記事で得た知識により、最先端のディープラーニング技術と堅牢なデータベース・ソリューションを用いた時系列予測の旅に出るための十分な準備が整いました。

DJLとGridDBの融合は、時系列予測の世界に新たな可能性を開きます。この分野をより深く掘り下げることで、データ駆動型洞察の力と、それが金融からヘルスケアまで幅広い業界にどのような革命をもたらすかを発見できるでしょう。探求し続けましょう、学び続け、自信を持って未来を予測し続けましょう。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0