LoginSignup
0

More than 1 year has passed since last update.

JavaによるK-Nearest Neighborアルゴリズム

Posted at

はじめに

K-nearest neighbors (KNN) は分類や回帰のタスクを実行するために用いられる教師あり機械学習アルゴリズムです。

KNNは、テストデータと学習点の距離を求めることで、テストデータに対して正しいクラスを予測します。このアルゴリズムでは、テストデータに最も近い点をK個選択します。そして、テストデータがK個のクラスに分類される確率を計算します。そして、最も高い確率を持つクラスが選択されます。

KNN アルゴリズムは、学習データから識別関数を学習するのではなく、学習データセットを記憶するため、遅延学習アルゴリズム と呼ばれます。

データセット

未知のデータに対して予測を行うために、顧客データセットを使用します。使用するデータセットは、以下のように、顧客の年齢、収入、購入した商品などを示しています。
image.png
年齢58歳、収入51000円のお客様に対するオススメ商品を見つけようと思います。
まず距離を計算し、次にkの値に応じて最も近いk人の近傍を取得することができるようにします。
デフォルトではkの値は1ですが、インスタンス生成時にkの値を渡すことができます。kが1の場合は1個の製品、つまり1件の最近傍が表示され,2の場合は2個の製品、つまり2件の最近傍が表示されます。

K-最近傍アルゴリズムの Java での実装

上記のデータセットを用いて、JavaでKNNアルゴリズムを実装してみます。データセットはcustomers.csvという名前のCSVファイルに保存されています。

csvファイルからデータを読み込み、GridDBにロードします。その後、GridDBからデータを取り出し、アルゴリズムを使って分析します。

パッケージのインポート

まず、使用する必要のあるパッケージをインポートしましょう。

import java.io.IOException;
import java.util.Collection;
import java.util.Properties;
import java.util.Scanner;
import java.io.File;

import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.GSException;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.GridStoreFactory;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.RowKey;
import com.toshiba.mwcloud.gs.RowSet;

データをGridDBに書き込む

CSVファイルからGridDBコンテナにデータを移動させたいと思います。まず、コンテナ・スキーマを静的クラスとして作成しましょう。

public static class Customers{

        @RowKey int customer;
        int age;
        Double income;
        String purchased_product;
}

上記のクラスは、コンテナや4つのカラムを持つSQLテーブルに似ています。

それでは、GridDBへの接続を確立してみましょう。GridDBをインストールしたときの指定情報に基づいてPropertiesのインスタンスを作成します。以下のコードを使用します。

        Properties props = new Properties();
        props.setProperty("notificationAddress", "239.0.0.1");
        props.setProperty("notificationPort", "31999");
        props.setProperty("clusterName", "defaultCluster");
        props.setProperty("user", "admin");
        props.setProperty("password", "admin");
        GridStore store = GridStoreFactory.getInstance().getGridStore(props);

GridDB のインストール環境に合わせて、上記の内容を変更します。
ここでは、Customersコンテナを使用するため、これを選択します。

Collection<String, Customers> coll = store.putCollection("col01", Customers.class);

コンテナCustomersのインスタンスが作成され、coll という名前が付けられました。このインスタンスを使って、コンテナを参照することにします。

データをGridDBに格納する

以下のJavaコードでcustomers.csvファイルからデータを読み込んで、GridDBに格納することができます。

                File file1 = new File("customers.csv");
                Scanner sc = new Scanner(file1);
                String data = sc.next();
 
                while (sc.hasNext()){
                        String scData = sc.next();
                        String dataList[] = scData.split(",");
                        String customer = dataList[0];
                        String age = dataList[1];
                        String income = dataList[2];
                        String purchased_product = dataList[3];
                        
                        Customers customers = new Customers();
                        customers.customer = Integer.parseInt(customer);
                        customers.age = Integer.parseInt(age);
                        customers.income = Double.parseDouble(income);
                        customers.purchased_product = purchased_product;
                        coll.append(customers);
                 }

顧客に関するデータでcustomersオブジェクトを作成しました。このオブジェクトは、GridDB コンテナに追加されます。

GridDB からデータを取得する

いよいよGridDBコンテナからデータを取り出します。以下のコードを使用します。

                Query<customers> query = coll.query("select *");
                RowSet<customers> rs = query.fetch(false);
                RowSet res = query.fetch();

select *文は、データベースコンテナからすべてのデータを問い合わせるのに役立ちます。

分類器の構築

いよいよ KNN アルゴリズムと読み込まれたデータを使って分類器を構築する時が来ました。そのために必要なライブラリをインポートしましょう。

import java.io.IOException;
import java.util.Enumeration;
import java.text.DecimalFormat;

import weka.classifiers.Classifier;
import weka.core.Instances;
import weka.classifiers.lazy.IBk;
import weka.classifiers.Evaluation;
import weka.core.Instance;
import weka.core.converters.ArffLoader;

それでは、モデルを構築し、その統計情報をプリントアウトしてみましょう。

res.setClassIndex(res.numAttributes() - 1);
        Classifier cls = new IBk(1);        
        cls.buildClassifier(res);
    
        System.out.println(cls);
       
        Evaluation evaluation = new Evaluation(res);
        evaluation.evaluateModel(cls, res);
        
        System.out.println(evaluation.toSummaryString());
        System.out.println(evaluation.toClassDetailsString());
        System.out.println(evaluation.toMatrixString());

IBk インスタンスを作成する際にkの値を指定しました。IBk インスタンスは整数の引数を取ります。1 という値を渡すと、1 つの最近傍を見つけることができます。2 を渡すと、2 つの最近傍を計算します。引数を渡さず、デフォルトのコンストラクタで呼び出した場合は、1近傍を計算します。今回の場合、1という値を渡しているので、顧客の最近傍を1件予測することになります。

コードのコンパイルと実行

まず、gsadmユーザでログインします。作成した.javaファイルを GridDB のbinフォルダに移動します。移動先は以下の通りです。
/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin
次に、Linux端末で以下のコマンドを実行し、gridstore.jarファイルのパスを設定します。

export CLASSPATH=$CLASSPATH:/home/osboxes/Downloads/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin/gridstore.jar

次に、以下のコマンドを実行して.javaファイルをコンパイルします。

javac KNNeighbor.java

以下のコマンドを実行して生成された .class ファイルを実行します。

java KNNeighbor

KNNモデルは、顧客の最近傍を1つ返します。

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0