はじめに
Processingそのものの開発スピードが速すぎてAPIが変化し、公開されているパッケージのサンプルが動かなくなっているケースをよく見かける。そんなケースを見かけたら、微修正したサンプルコードをさらすことにする。今日は、PSVMのサンプルコード。その他、折角なので、クラス化しておく。
ちなみにPSVMとは、サポートベクターマシンをProcessingで使えるようにするJavaパッケージです。末尾のリンクを参照してください。
PSVMのサンプルコード
サンプルコードには以下のパッケージを使用しています。
- logback-classic-0.9.30.jar
- logback-core-0.9.30.jar
- slf4j-api-1.6.3.jar
上記の3つはログ用に使用しています。以下のサイトから取得できますが、面倒なら、コードの中のlogger.error()という関数をSystem.err.println()に置き換えれば使う必要はありません。
さて、本題。
SVMComponent.java
import psvm.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
import java.util.ArrayList;
import java.util.Set;
import java.util.HashSet;
import java.util.Arrays;
import java.lang.Class;
import java.lang.reflect.Field;
import processing.data.Table;
import processing.data.TableRow;
import processing.core.PApplet;
import java.io.File;
/**
 * Thin wrapper around the PSVM {@code SVM} / {@code SVMProblem} pair that
 * loads labelled training data from a table file, trains a model and
 * classifies new points.
 *
 * <p>Every public operation reports failure through its return value
 * ({@code false} or {@code -1}) and logs the cause instead of throwing.
 */
class SVMComponent {
    protected static Logger logger = LoggerFactory.getLogger(SVMComponent.class);
    protected PApplet applet = null;
    protected String labelColumnName = null;

    /** Label of each training sample; {@code null} until data has been imported successfully. */
    public int[] labels = null;
    /** Feature vectors, one row per sample; {@code null} until data has been imported successfully. */
    public float[][] trainingPoints = null;
    public SVM model = null;
    public SVMProblem problem = null;

    /**
     * Creates the component with a fresh model and an empty problem.
     *
     * @param applet the sketch handed to the {@code SVM} constructor
     */
    public SVMComponent(PApplet applet) {
        this.applet = applet;
        model = new SVM(applet);
        problem = new SVMProblem();
    }

    /**
     * Saves the model by delegating to {@code SVM.saveModel}.
     *
     * @param modelPath destination path
     * @return {@code true} on success, {@code false} if an exception was raised
     */
    public boolean save(String modelPath) {
        try {
            this.model.saveModel(modelPath);
            return true;
        } catch (Exception e) {
            logger.error("Error: {}", e);
            return false;
        }
    }

    /**
     * Loads a model by delegating to {@code SVM.loadModel}.
     *
     * @param modelPath    path of the model file
     * @param columnNumber passed through to {@code SVM.loadModel}
     *                     (presumably the number of features; confirm against PSVM)
     * @return {@code true} on success, {@code false} if an exception was raised
     */
    public boolean load(String modelPath, int columnNumber) {
        try {
            this.model.loadModel(modelPath, columnNumber);
            return true;
        } catch (Exception e) {
            logger.error("Error: {}", e);
            return false;
        }
    }

    /**
     * Imports training data and rescales every feature to the range [0, 1]
     * using the minimum and maximum observed for that feature.
     *
     * <p>{@link #labels} and {@link #trainingPoints} are replaced only when
     * the whole file was read successfully.
     *
     * @param filepath   path of the table file
     * @param labelIndex zero-based index of the column holding the label;
     *                   all other columns are treated as features
     * @return {@code true} on success, {@code false} if reading or parsing failed
     */
    public boolean importDataFromFile(String filepath, int labelIndex) {
        try {
            Table data = new Table(new File(filepath));
            int[] loadedLabels = new int[data.getRowCount()];
            float[][] points = readSamples(data, labelIndex, loadedLabels);
            int featureNum = data.getColumnCount() - 1;
            normalizeMinMax(points, featureNum);
            this.labels = loadedLabels;
            this.trainingPoints = points;
            this.problem.setNumFeatures(featureNum);
            return true;
        } catch (Exception e) {
            logger.error("Error: {}", e);
            return false;
        }
    }

    /**
     * Imports training data and divides every feature value by a fixed range.
     *
     * <p>{@link #labels} and {@link #trainingPoints} are replaced only when
     * the whole file was read successfully.
     *
     * @param filepath   path of the table file
     * @param labelIndex zero-based index of the column holding the label;
     *                   all other columns are treated as features
     * @param dataRange  divisor applied to every feature value; must be a
     *                   non-zero number
     * @return {@code true} on success, {@code false} if reading or parsing
     *         failed or {@code dataRange} is unusable
     */
    public boolean importDataFromFile(String filepath, int labelIndex, double dataRange) {
        try {
            // Dividing by zero or NaN would silently fill the data with Infinity/NaN.
            if (dataRange == 0 || Double.isNaN(dataRange)) {
                throw new IllegalArgumentException("dataRange must be a non-zero number: " + dataRange);
            }
            Table data = new Table(new File(filepath));
            int featureNum = data.getColumnCount() - 1;
            logger.debug("Data details:");
            logger.debug(" Num of columns: {}", featureNum);
            logger.debug(" Num of training points: {}", data.getRowCount());
            int[] loadedLabels = new int[data.getRowCount()];
            float[][] points = readSamples(data, labelIndex, loadedLabels);
            for (float[] point : points) {
                for (int j = 0; j < featureNum; j++) {
                    point[j] /= dataRange;
                }
            }
            this.labels = loadedLabels;
            this.trainingPoints = points;
            this.problem.setNumFeatures(featureNum);
            return true;
        } catch (Exception e) {
            logger.error("Error: {}", e);
            return false;
        }
    }

    /**
     * Hands the imported samples to the problem and trains the model.
     *
     * @return {@code true} on success, {@code false} if an exception was raised
     */
    public boolean train() {
        try {
            problem.setSampleData(this.labels, this.trainingPoints);
            this.model.train(problem);
            return true;
        } catch (Exception e) {
            // Log the exception itself: getMessage() is null for e.g. NullPointerException.
            logger.error("Error: {}", e);
            return false;
        }
    }

    /**
     * Classifies one point with the current model.
     *
     * @param testSet feature values of the point to classify
     * @return the result of {@code SVM.test} cast to {@code int}, or {@code -1}
     *         if an exception was raised
     */
    public int test(double[] testSet) {
        try {
            return (int) this.model.test(testSet);
        } catch (Exception e) {
            logger.error("Error: {}", e);
            return -1;
        }
    }

    /**
     * Splits every table row into its label and its feature values.
     *
     * @param data       the table to read
     * @param labelIndex zero-based index of the label column
     * @param labelsOut  receives one label per row; its length is the row count
     * @return the feature values, one row per sample, label column removed
     * @throws IllegalArgumentException if {@code labelIndex} is not a column of {@code data}
     */
    private static float[][] readSamples(Table data, int labelIndex, int[] labelsOut) {
        int columnCount = data.getColumnCount();
        if (labelIndex < 0 || labelIndex >= columnCount) {
            throw new IllegalArgumentException(
                "labelIndex " + labelIndex + " is outside the " + columnCount + " available columns");
        }
        float[][] points = new float[labelsOut.length][columnCount - 1];
        int rowIndex = 0;
        for (TableRow row : data.rows()) {
            int featureIndex = 0;
            for (int column = 0; column < columnCount; column++) {
                if (column == labelIndex) {
                    labelsOut[rowIndex] = row.getInt(column);
                } else {
                    // Read from the table column, write to the compacted feature slot.
                    points[rowIndex][featureIndex++] = row.getFloat(column);
                }
            }
            rowIndex++;
        }
        return points;
    }

    /**
     * Rescales each feature column in place to [0, 1]; a column whose values
     * are all identical becomes 0 instead of dividing by zero.
     *
     * @param points     samples to rescale, modified in place
     * @param featureNum number of feature columns in {@code points}
     */
    private static void normalizeMinMax(float[][] points, int featureNum) {
        if (points.length == 0) {
            return;
        }
        for (int j = 0; j < featureNum; j++) {
            float min = Float.MAX_VALUE;
            float max = -Float.MAX_VALUE; // not 0, so all-negative columns work
            for (float[] point : points) {
                if (point[j] > max) {
                    max = point[j];
                }
                if (point[j] < min) {
                    min = point[j];
                }
            }
            float range = max - min;
            logger.debug("data range is {}", range);
            for (float[] point : points) {
                point[j] = (range > 0) ? (point[j] - min) / range : 0f;
            }
        }
    }
}
Processingのサンプルコードは以下のとおり。
SVMApp.pde
// Off-screen buffer that holds the rendered classification map.
PGraphics modelDisplay;
// When true, draw() paints the classification map behind the points.
boolean showModel;
// Wrapper holding the training data and the SVM model.
SVMComponent svm;
// Prepares the window, loads the training data, obtains a trained model
// (from disk when available, otherwise by training) and pre-renders the
// classification map.
void setup () {
  size(500, 500);
  // Displaying the model is very slow, so it is rendered once into a
  // PGraphics buffer and reused by draw().
  modelDisplay = createGraphics(500, 500);
  svm = new SVMComponent(this);

  // sketchPath(String) resolves a name inside the sketch folder and is
  // available in both Processing 2 and 3; the bare sketchPath field was
  // removed in Processing 3.
  String dataPath = sketchPath("points.csv");
  String modelPath = sketchPath("trained.txt");

  // Column 2 holds the label; coordinates are divided by 500 to fit 0..1.
  if (!svm.importDataFromFile(dataPath, 2, 500)) {
    println("Failed to load training data from " + dataPath);
    exit();
    return;
  }

  // Reuse a saved model when it loads cleanly; otherwise train a new one.
  boolean loaded = new File(modelPath).exists() && svm.load(modelPath, 2);
  if (!loaded) {
    if (svm.train()) {
      svm.save(modelPath);
    } else {
      println("Failed to train the model");
    }
  }
  drawModel();
}
// Paints the background (the pre-rendered model or plain white) and then
// every training point, coloured by its label.
void draw(){
  if (!showModel) {
    background(255);
  } else {
    image(modelDisplay, 0, 0);
  }
  stroke(255);
  for (int idx = 0; idx < svm.trainingPoints.length; idx++) {
    // Labels 1, 2 and 3 map to red, green and blue; any other label keeps
    // whatever fill colour is currently active.
    switch (svm.labels[idx]) {
      case 1:
        fill(255, 0, 0);
        break;
      case 2:
        fill(0, 255, 0);
        break;
      case 3:
        fill(0, 0, 255);
        break;
      default:
        break;
    }
    // Stored coordinates are scaled back up by 500 for display.
    float[] pt = svm.trainingPoints[idx];
    ellipse(pt[0] * 500, pt[1] * 500, 5, 5);
  }
}
// Keyboard controls: space toggles the model background, 's' saves the
// model to model.txt in the sketch folder for use in future classification.
void keyPressed(){
  if (key == ' ') {
    showModel = !showModel;
  }
  if (key == 's') {
    // sketchPath(String) works in Processing 2 and 3; the bare sketchPath
    // field was removed in Processing 3.
    svm.save(sketchPath("model.txt"));
  }
}
// On a mouse click, classifies the clicked position (scaled to 0..1 by the
// sketch width and height) and prints the predicted label.
void mousePressed () {
  double normX = (double) mouseX / width;
  double normY = (double) mouseY / height;
  int predicted = svm.test(new double[] { normX, normY });
  println(predicted);
}
// Colours every pixel of the sketch area according to the label the model
// predicts for that position and stores the result in modelDisplay, so the
// map can be shown beneath the data without being recomputed each frame.
void drawModel(){
  // Draw into the PGraphics buffer instead of the sketch window.
  modelDisplay.beginDraw();
  for (int px = 0; px < width; px++) {        // each pixel column
    for (int py = 0; py < height; py++) {     // each pixel row
      // Position scaled to 0..1, handed to the model for classification.
      double[] sample = { (double) px / width, (double) py / height };
      int predicted = svm.test(sample);
      // Labels 1, 2 and 3 map to red, green and blue; any other result
      // keeps whatever stroke colour is currently active.
      switch (predicted) {
        case 1:
          modelDisplay.stroke(255, 0, 0);
          break;
        case 2:
          modelDisplay.stroke(0, 255, 0);
          break;
        case 3:
          modelDisplay.stroke(0, 0, 255);
          break;
        default:
          break;
      }
      modelDisplay.point(px, py);
    }
  }
  // Finished with the PGraphics buffer.
  modelDisplay.endDraw();
}
processing.data.Tableのコンストラクタなどの仕様が変わっているので注意。
今後の課題
サンプルに倣って、GUIも用意しようと思う。