事の始まりとか
学校の課題でFisherの線形判別分析をやることになった時、Kotlinで実装したいと思ったのでやってみた。
線形判別分析とは
友人が解説してくれているのでそちらを参照するとすごく分かりやすい。
環境
今回はJava環境さえあれば勝手に環境を整えてくれて実行できる方が良いと思ったのでGradle+Kotlin+ND4Jで実装した。
Gradleについて
シンプルに必要なものを書いた。Kotlinもgradleを使ってコンパイル/実行が出来るので便利。
コメントアウトになっている部分はcuda実行用のライブラリです。上のnative-platform
の方をコメントアウトし、cuda側のコメントアウトを削除すればcudaで動く。
version '1.0-SNAPSHOT'
buildscript {
ext.kotlin_version = '1.2.10'
repositories {
mavenCentral()
}
dependencies {
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
}
}
apply plugin: 'kotlin'
apply plugin: 'java'
apply plugin: 'application'
tasks.withType(JavaCompile){ options.encoding = 'UTF-8' }
mainClassName = 'suz.MainKt'
repositories {
mavenCentral()
}
sourceSets {
main.java.srcDirs = ['src/main/kotlin']
}
dependencies {
compile "org.jetbrains.kotlin:kotlin-stdlib-jre8:$kotlin_version"
compile "org.nd4j:nd4j-native-platform:0.9.1"
// compile "org.nd4j:nd4j-cuda-8.0-platform:0.9.1"
compile fileTree(dir: 'libs', include: '*.jar')
}
compileKotlin {
kotlinOptions.jvmTarget = "1.8"
}
jar {
from configurations.compile.collect { it.isDirectory() ? it : zipTree(it) }
}
Main
こちらもシンプルに必要なデータの生成を行って呼び出しているだけです。
内容は上項目線形判別分析とはで紹介されている内容と同じです。Python→Java(Kotlin)にしただけです。
Javaのコード規約の関係で若干名前等が長いですが行数的な長さはあまり変わらないと思います。
後はKotlinは数値の多次元配列の定義がJavaよりも煩雑。このあたりは数値計算やる上で障害になるので今後どうにかして欲しいところ。
fun main(args: Array<String>) {
val cov = arrayOf(doubleArrayOf(3.0, 1.0), doubleArrayOf(1.0, 3.0))
var mnd1 = MultivariateNormalDistribution(doubleArrayOf(-5.0, -5.0), cov)
var mnd2 = MultivariateNormalDistribution(doubleArrayOf(5.0, 5.0), cov)
val classData1 = Nd4j.create(mnd1.sample(50))
val classData2 = Nd4j.create(mnd2.sample(50))
var predict = LDA(classData1, classData2)
val result = predict.train()
println("w = [${result!!.getDouble(0, 0)},${result!!.getDouble(1, 0)}]")
}
LDA計算部分
こちらもほぼ紹介そのまま落とし込めています。コードが似ているのはND4JがPythonのNumpyを参考に実装しているのでそこは当たり前といえば当たり前かな...
グラフ表示は課題で使わなかったので端折りました。ライブラリを使えばサクッと実装できると思います。
DataSetクラスを用いると簡単にデータの保存/呼び出しが行えて便利です。
class LDA(val classData1: INDArray, val classData2: INDArray) {
var result: INDArray? = null
val DATA_SIZE = 2
/**
* 学習部分
*/
fun train(): INDArray? {
var mean1 = Nd4j.mean(classData1, 0).reshape(classData1.size(1), 1) //次元0の平均計算
var mean2 = Nd4j.mean(classData2, 0).reshape(classData2.size(1), 1) //次元0の平均計算
//総クラス内の共分散行列
var sw = Nd4j.zeros(DATA_SIZE, DATA_SIZE)
for (index in 0 until classData1.size(0)) {
val shape = classData1.transpose().getColumn(index).reshape(DATA_SIZE, 1)
val sub = shape.sub(mean1)
sw = sw.add(sub.mmul(sub.transpose()))
}
for (index in 0 until classData2.size(0)) {
val shape = classData2.transpose().getColumn(index).reshape(DATA_SIZE, 1)
val sub = shape.sub(mean2)
sw = sw.add(sub.mmul(sub.transpose()))
}
var swInv = InvertMatrix.invert(sw, false)
result = swInv.mmul(mean1.sub(mean2)) //傾きwを求めている
return result
}
fun save() {
val resultFile = File("./linearResult")
DataSet(result, Nd4j.zeros(1, 1)).save(resultFile)
}
fun load() {
val resultFile = File("./linearResult")
var data = DataSet()
data.load(resultFile)
result = data.featureMatrix
}
}
起きた問題とか
ドキュメントが少ない。
ND4jの公式ページにインストールや実行に必要な最低限の事は日本語でも書いてあるのだが、行列計算は英語ページ、更にデータ保存やNumpyデータの読み込みなどはJavadocを見ないと分からないので検索力が求められる。また、知名度が低いのかND4jを用いた記事に関しても英語・日本語で絶望的に少ないので大変かもしれません。
しかし大体の行列演算に関することはND4jで網羅されているので見つけられれば実装できると思います。
まぁ、楽をしたいならPythonってところですかね。
実装したソース
ほぼ全て記事に書きましたが、一応全貌です。
https://github.com/Khromium/torikarasu/tree/qiita