0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

PSVMのサンプルコード@Processing2

Posted at 2013-10-20

はじめに

Processingそのものの開発スピードが速すぎてAPIが変化し、公開されているパッケージのサンプルが動かなくなっているケースをよく見かける。そうしたケースを見つけたら、微修正したサンプルコードを公開していくことにする。今回はPSVMのサンプルコード。また、せっかくなのでクラス化しておく。

ちなみにPSVMとは、サポートベクターマシン（SVM）をProcessingで使えるようにするJavaパッケージです。詳しくは末尾のリンクを参照してください。

PSVMのサンプルコード

サンプルコードには以下のパッケージを使用しています。

  • logback-classic-0.9.30.jar
  • logback-core-0.9.30.jar
  • slf4j-api-1.6.3.jar

上記の3つはログ出力に使用しています。以下のサイトから取得できますが、面倒なら、コード中のlogger.error()やlogger.debug()の呼び出しをSystem.err.println()に置き換え（SLF4Jの{}プレースホルダは使えないので、メッセージは文字列連結で組み立てる）、org.slf4jのimportとloggerフィールドを削除すれば、これらのパッケージは不要です。

さて、本題。

SVMComponent.java
import psvm.*;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.ArrayList;
import java.util.Set;
import java.util.HashSet;
import java.util.Arrays;

import java.lang.Class;
import java.lang.reflect.Field;

import processing.data.Table;
import processing.data.TableRow;
import processing.core.PApplet;

import java.io.File;

/**
 * Thin wrapper around PSVM's {@code SVM} / {@code SVMProblem}.
 *
 * <p>It loads labelled training points from a data file, trains or loads a model,
 * and classifies new points. Every public operation catches its own exceptions,
 * logs them, and reports failure through its return value instead of throwing.
 */
class SVMComponent {
	protected static final Logger logger = LoggerFactory.getLogger(SVMComponent.class);

	protected PApplet applet = null;
	protected String labelColumnName = null;

	/** Class label of each training point; indexed like {@link #trainingPoints}. */
	public int [] labels = null;
	/** Feature vector of each training point, filled by importDataFromFile. */
	public float [][] trainingPoints = null;

	public SVM model = null;
	public SVMProblem problem = null;

	/**
	 * Creates an untrained model bound to the given sketch.
	 *
	 * @param applet the owning sketch, passed through to PSVM's {@code SVM}
	 */
	public SVMComponent (PApplet applet) {
		this.applet = applet;
		this.model = new SVM(applet);
		this.problem = new SVMProblem();
	}

	/**
	 * Writes the current model to {@code modelPath}.
	 *
	 * @return true on success, false if saving threw (the error is logged)
	 */
	public boolean save (String modelPath) {
		try {
			this.model.saveModel(modelPath);
			return true;
		} catch (Exception e) {
			// exception as the trailing argument => SLF4J logs its stack trace
			logger.error("Could not save model to {}", modelPath, e);
			return false;
		}
	}

	/**
	 * Replaces the current model with one previously written by {@link #save}.
	 *
	 * @param modelPath    file to read
	 * @param columnNumber passed unchanged to PSVM's {@code loadModel}
	 *                     (presumably the number of features; TODO confirm)
	 * @return true on success, false if loading threw (the error is logged)
	 */
	public boolean load (String modelPath, int columnNumber) {
		try {
			this.model.loadModel(modelPath, columnNumber);
			return true;
		} catch (Exception e) {
			logger.error("Could not load model from {}", modelPath, e);
			return false;
		}
	}

	/**
	 * Loads labelled training points and min-max normalizes every feature to [0, 1].
	 *
	 * <p>The file is read with Processing's {@code Table}; the sketch passes a CSV
	 * file. Each row is one training point. The column at {@code labelIndex} is read
	 * as the integer class label, and the remaining columns, in file order, form the
	 * feature vector. A feature whose values are all equal is mapped to 0.
	 *
	 * @param filepath   path of the data file
	 * @param labelIndex zero-based index of the label column
	 * @return true on success; false if the file could not be read or
	 *         {@code labelIndex} is not a valid column (the error is logged)
	 */
	public boolean importDataFromFile (String filepath, int labelIndex) {
		try {
			int featureNum = readRows(new Table(new File(filepath)), labelIndex);
			normalizeByMinMax(featureNum);
			this.problem.setNumFeatures(featureNum);
			return true;
		} catch (Exception e) {
			logger.error("Could not import data from {}", filepath, e);
			return false;
		}
	}

	/**
	 * Loads labelled training points and divides every feature by {@code dataRange}.
	 *
	 * <p>The file layout is the same as for {@link #importDataFromFile(String, int)}.
	 *
	 * @param filepath   path of the data file
	 * @param labelIndex zero-based index of the label column
	 * @param dataRange  divisor applied to every feature value; must be non-zero
	 * @return true on success; false if the file could not be read, or if
	 *         {@code labelIndex} or {@code dataRange} is invalid (the error is logged)
	 */
	public boolean importDataFromFile (String filepath, int labelIndex, double dataRange) {
		try {
			if (dataRange == 0 || Double.isNaN(dataRange)) {
				throw new IllegalArgumentException("dataRange must be a non-zero number, got " + dataRange);
			}
			Table data = new Table(new File(filepath));
			int featureNum = readRows(data, labelIndex);

			logger.debug("Data details:");
			logger.debug("  Num of columns: {}", featureNum);
			logger.debug("  Num of training points: {}", data.getRowCount());

			for (float[] point : this.trainingPoints) {
				for (int j = 0; j < featureNum; j++) {
					point[j] /= dataRange;
				}
			}

			this.problem.setNumFeatures(featureNum);
			return true;
		} catch (Exception e) {
			logger.error("Could not import data from {}", filepath, e);
			return false;
		}
	}

	/**
	 * Trains the model on the currently imported {@link #labels} and
	 * {@link #trainingPoints}.
	 *
	 * @return true on success, false if training threw (the error is logged)
	 */
	public boolean train () {
		try {
			problem.setSampleData(this.labels, this.trainingPoints);
			this.model.train(problem);
			return true;
		} catch (Exception e) {
			logger.error("Training failed", e);
			return false;
		}
	}

	/**
	 * Classifies one point with the current model.
	 *
	 * @param testSet feature vector of the point to classify
	 * @return the predicted label truncated to int, or -1 if classification threw
	 *         (the error is logged). Note that -1 cannot be told apart from a
	 *         genuine label of -1.
	 */
	public int test (double [] testSet) {
		try {
			return (int) this.model.test(testSet);
		} catch (Exception e) {
			logger.error("Could not classify point {}", Arrays.toString(testSet), e);
			return -1;
		}
	}

	/**
	 * Reads every row of {@code data} into {@link #labels} and {@link #trainingPoints}.
	 *
	 * <p>Both fields are replaced only after all rows were read successfully.
	 *
	 * @return the number of features per point (column count minus the label column)
	 * @throws IllegalArgumentException if {@code labelIndex} is not a column of {@code data}
	 */
	private int readRows (Table data, int labelIndex) {
		int columnCount = data.getColumnCount();
		if (labelIndex < 0 || labelIndex >= columnCount) {
			throw new IllegalArgumentException(
					"labelIndex " + labelIndex + " is out of range for " + columnCount + " columns");
		}
		int featureNum = columnCount - 1;
		int rowCount = data.getRowCount();
		float[][] points = new float[rowCount][];
		int[] rowLabels = new int[rowCount];

		int i = 0;
		for (TableRow row : data.rows()) {
			float[] p = new float[featureNum];
			int feature = 0;
			for (int column = 0; column < columnCount; column++) {
				if (column == labelIndex) {
					rowLabels[i] = row.getInt(column);
				} else {
					// read the actual column; only the destination index skips the label
					p[feature++] = row.getFloat(column);
				}
			}
			points[i++] = p;
		}

		this.trainingPoints = points;
		this.labels = rowLabels;
		return featureNum;
	}

	/**
	 * Min-max scales each of the first {@code featureNum} features of
	 * {@link #trainingPoints} to [0, 1].
	 *
	 * <p>A feature with zero (or NaN) range is set to 0 instead of being divided by zero.
	 */
	private void normalizeByMinMax (int featureNum) {
		for (int j = 0; j < featureNum; j++) {
			float min = Float.POSITIVE_INFINITY;
			float max = Float.NEGATIVE_INFINITY;
			for (float[] point : this.trainingPoints) {
				min = Math.min(min, point[j]);
				max = Math.max(max, point[j]);
			}
			float range = max - min;
			logger.debug("data range of feature {} is {}", j, range);
			for (float[] point : this.trainingPoints) {
				point[j] = range > 0 ? (point[j] - min) / range : 0f;
			}
		}
	}

}

Processing側のサンプルコードは以下の通り。

SVMApp.pde
// Off-screen buffer holding the pre-rendered model predictions (filled by drawModel()).
PGraphics modelDisplay;
// Whether draw() shows the model predictions behind the training points.
boolean showModel;
// The classifier together with the training data it was built from.
SVMComponent svm;

// Loads the training points, obtains a model (cached on disk or freshly trained)
// and pre-renders the model's predictions. Quits the sketch if the data cannot be
// imported or training fails, since draw() has nothing meaningful to show then.
void setup () {
	size(500, 500);

	// rendering the model is very slow, so it is done once into an
	// off-screen PGraphics that draw() can simply blit
	modelDisplay = createGraphics(width, height);

	svm = new SVMComponent(this);
	// column 2 holds the label; features are divided by 500, presumably because
	// the points were recorded on a 500x500 canvas (TODO confirm)
	if (!svm.importDataFromFile(sketchPath + "/points.csv", 2, 500)) {
		println("Could not import points.csv - see the log for details");
		exit();
		return;
	}

	// reuse a cached model if one exists and loads cleanly; otherwise train
	// from the imported points and cache the result for the next run
	String modelPath = sketchPath + "/trained.txt";
	boolean loaded = new File(modelPath).exists() && svm.load(modelPath, 2);
	if (!loaded) {
		if (!svm.train()) {
			println("Training failed - see the log for details");
			exit();
			return;
		}
		svm.save(modelPath);
	}

	drawModel();
}


// Draws the (optional) model background and then every training point,
// colored by its label: red = 1, green = 2, blue = 3, gray = anything else.
void draw(){
	// show the pre-rendered model as the background, or plain white
	if (showModel) {
		image(modelDisplay, 0, 0);
	} else {
		background(255);
	}

	// nothing to plot until training data has been imported
	if (svm == null || svm.trainingPoints == null) {
		return;
	}

	stroke(255);

	// mousePressed() and drawModel() map the canvas into the model's space by
	// dividing by width/height, so map the training points back the same way
	for (int i = 0; i < svm.trainingPoints.length; i++) {
		int label = svm.labels[i];
		if (label == 1) {
			fill(255, 0, 0);
		} else if (label == 2) {
			fill(0, 255, 0);
		} else if (label == 3) {
			fill(0, 0, 255);
		} else {
			// don't silently reuse the previous point's color
			fill(128);
		}
		ellipse(svm.trainingPoints[i][0] * width, svm.trainingPoints[i][1] * height, 5, 5);
	}
}

// Keyboard controls:
//   space - toggle the model background
//   s     - write the current model to model.txt for later classification
void keyPressed(){
	switch (key) {
	case ' ':
		showModel = !showModel;
		break;
	case 's':
		svm.save(sketchPath + "/model.txt");
		break;
	}
}

// Classifies the clicked position with the model and prints the predicted label.
void mousePressed () {
	// scale the mouse position by the sketch size, as drawModel() does per pixel
	double[] clicked = { (double) mouseX / width, (double) mouseY / height };
	int predicted = svm.test(clicked);
	println(predicted);
}

// Colors every pixel of modelDisplay by the label the model predicts for that
// x-y position, so draw() can show it beneath the data:
// red = 1, green = 2, blue = 3, gray = anything else (including the -1 that
// SVMComponent.test() returns when classification fails).
// This is slow (one model query per pixel), which is why it runs only once.
void drawModel(){
	// start drawing into the PGraphics instead of the sketch
	modelDisplay.beginDraw();
	for (int x = 0; x < width; x++) {          // each column
		for (int y = 0; y < height; y++) {     // each row
			// scale the pixel position by the sketch size, as mousePressed() does
			double[] testPoint = { (double) x / width, (double) y / height };
			int label = svm.test(testPoint);

			if (label == 1) {
				modelDisplay.stroke(255, 0, 0);
			} else if (label == 2) {
				modelDisplay.stroke(0, 255, 0);
			} else if (label == 3) {
				modelDisplay.stroke(0, 0, 255);
			} else {
				// unknown label or error: don't reuse the previous pixel's color
				modelDisplay.stroke(128);
			}
			modelDisplay.point(x, y);
		}
	}
	// we're done with the PGraphics
	modelDisplay.endDraw();
}

processing.data.Tableのコンストラクタなどの仕様が変わっているので注意。

今後の課題

サンプルに倣って、GUIも用意しようと思う。

参考URL

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?