なぜ記事にしたか
意外に、私のまとまりのない記事を読んでくれる人がいたので、地球は1つ→人類皆兄弟→仕事仲間意識が勝手に芽生え、たまに書いています。(あと、自分のためにもなるからです。)
やってみたこと
- LSTMをモデルのレイヤーに使う
- RNA-Seqデータ(801×20531)を使ってがんの遺伝子(5種類)をマルチクラス分類で予測する
つかったもの
- ノートパソコン(一般のもの)
- Optional: NVIDIA-GPU(今回は1050Ti)(AMDはだめ。dl4jがCUDA頼りなので)
環境
- ubuntu 18.04 (javaなのでOSは細かく問わない)
- maven + dl4j関連
- eclipse (2018)
- JDK8
- (もう、こういうテーマに興味ある人はpythonでやってますよね、、)
データ
このデータは、the cancer genome atlas pan-cancer analysis projectにて、イルミナ社のHiSeqという高性能なRNA解析装置を用いて取得されたデータセットからランダムに抽出された一部のデータ。
詳しくは、https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3919969/
RNA-Seqは、リファレンス配列に対してどのくらい遺伝子が発現していそうかを遺伝子ごとに数値化して、横に並べたもの。
リファレンス配列は、遺伝子のパターンがある程度分かっているときに、研究者側で決めておく遺伝子のパターン。
配列のパターンは、個体の性質・属性を説明できる可能性があるため注目されている。
特に、がん細胞のRNAの探索は創薬研究でずっと前から行われてきた。詳しい人によると、人のゲノムが解明されてもまだまだ仕事が山積みらしい。
分類するもの
この研究では以下のがん種が検討されているが、サンプルデータにはこれらの内の5つ(★)が含まれている。
- LUSC Lung Squamous Carcinoma
- READ Rectal Adeno Carcinoma
- GBM Glioblastoma Multiform
- LAML Lymphoblastic Acute Myeloid Leukemia
- HNSC Head and Neck Squamous Carcinoma
- BLCA Bladder Carcinoma
- KIRC Kidney Renal Clear Cell Carcinoma(★)
- UCEC Uterine Cervical and Endometrial Carcinoma
- LUAD Lung Adenocarcinoma(★)
- OV Cvarian Carcinoma
- BRCA Breast Carcinoma(★)
- COAD Colon Adenocarcinoma(★)
- PRAD Prostate Adenocarcinoma(★)
この5つのがんの種類を分類対象とする。
プロジェクトの準備
mavenプロジェクトを作り、任意のクラスファイルを作成する。
CSVの準備
//csvの読み込み
File dataFile = new File("./TCGA-PANCAN-HiSeq-801x20531/data.csv");
File labelFile = new File("./TCGA-PANCAN-HiSeq-801x20531/labels.csv");
データの抽出
int numClasses = 5; //5 classes
int batchSize = 801; //samples total
//先にデータ本体から中身を取得
RecordReader reader = new CSVRecordReader(1,',');//skip header
try {
reader.initialize(new FileSplit(dataFile));
} catch (IOException | InterruptedException e) {
e.printStackTrace();
}
double[][] dataObj = new double[batchSize][];
int itr = 0;
while(reader.hasNext()) {
List<Writable> row = reader.next();
double scalers[] = new double[row.size()-1];
for(int i = 0; i < row.size()-1; i++) {
if(i == 0) {//skip subject
continue;
}
double scaler = Double.parseDouble(new ConvertToString().map(row.get(i)).toString());
scalers[i] = scaler;
}
dataObj[itr] = scalers;
itr++;
}
System.out.println("Data samples "+ +dataObj.length);//801
//ラベルの読み込み
//マルチラベル用に変換も行う
//label
try {
reader = new CSVRecordReader(1,',');//skip header
reader.initialize(new FileSplit(labelFile));
} catch (IOException | InterruptedException e) {
e.printStackTrace();
}
double[][] labels = new double[batchSize][];
itr = 0;
while(reader.hasNext()) {
List<Writable> row = reader.next();
double scalers[] = null;
for(int i = 0; i < row.size(); i++) {
if(i == 0) {//skip subject
continue;
}
// Class
if(i == 1) {
String classname = new ConvertToString().map(row.get(i)).toString();
switch(classname) {
case "BRCA":
scalers = new double[]{1,0,0,0,0};
break;
case "PRAD":
scalers = new double[]{0,1,0,0,0};
break;
case "LUAD":
scalers = new double[]{0,0,1,0,0};
break;
case "KIRC":
scalers = new double[]{0,0,0,1,0};
break;
case "COAD":
scalers = new double[]{0,0,0,0,1};
break;
default:
break;
}
labels[itr] = scalers;
itr++;
}
}
}
System.out.println("LABEL : "+labels.length);//801
データをDataSetオブジェクへするために、一旦INDArrayに
//DataSetをつくる
INDArray dataArray = Nd4j.create(dataObj,'c');
System.out.println(dataArray.shapeInfoToString());
INDArray labelArray = Nd4j.create(labels,'c');
System.out.println(labelArray.shapeInfoToString());
//Rank: 2,Offset: 0
// Order: c Shape: [801,20531], stride: [20531,1]
//Rank: 2,Offset: 0
// Order: c Shape: [801,5], stride: [5,1]
DataSetへ
DataSet dataset = new DataSet(dataArray, labelArray);
SplitTestAndTrain sp = dataset.splitTestAndTrain(600, new Random(42L));//600 train, 201 test
DataSet train = sp.getTrain();
DataSet test = sp.getTest();
System.out.println(train.labelCounts());
System.out.println(test.labelCounts());
//{0=220.0, 1=105.0, 2=104.0, 3=109.0, 4=62.0}
//{0=80.0, 1=31.0, 2=37.0, 3=37.0, 4=16.0}
モデルの構築・トレーニング・評価
//MODEL TRAIN AND EVALUATION
int numInput = 20531;
int numOutput = numClasses;
int hiddenNode = 500;//非力
int numEpochs = 50;
MultiLayerConfiguration LSTMConf = new NeuralNetConfiguration.Builder()
.seed(123)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.weightInit(WeightInit.XAVIER)
.updater(new Adam(0.001))
.list()
.layer(0,new LSTM.Builder()
.nIn(numInput)
.nOut(hiddenNode)
.activation(Activation.RELU)
.build())
.layer(1,new LSTM.Builder()
.nIn(hiddenNode)
.nOut(hiddenNode)
.activation(Activation.RELU)
.build())
.layer(2,new LSTM.Builder()
.nIn(hiddenNode)
.nOut(hiddenNode)
.activation(Activation.RELU)
.build())
.layer(3,new RnnOutputLayer.Builder()
.nIn(hiddenNode)
.nOut(numOutput)
.activation(Activation.SOFTMAX)
.lossFunction(LossFunction.MCXENT)//multi class cross entropy
.build())
.pretrain(false)
.backprop(true)
.build();
MultiLayerNetwork model = new MultiLayerNetwork(LSTMConf);
model.init();
System.out.println("TRAIN START...");
for(int i=0;i<numEpochs;i++) {
model.fit(train);
}
System.out.println("EVALUATION START...");
Evaluation eval = new Evaluation(5);
for(DataSet row :test.asList()) {
INDArray testdata = row.getFeatures();
INDArray pred = model.output(testdata);
eval.eval(row.getLabels(), pred);
}
System.out.println(eval.stats());
評価結果の出力
TRAIN START...
EVALUATION START...
Predictions labeled as 0 classified by model as 0: 80 times
Predictions labeled as 1 classified by model as 1: 31 times
Predictions labeled as 2 classified by model as 0: 3 times
Predictions labeled as 2 classified by model as 2: 34 times
Predictions labeled as 3 classified by model as 2: 1 times
Predictions labeled as 3 classified by model as 3: 36 times
Predictions labeled as 4 classified by model as 2: 4 times
Predictions labeled as 4 classified by model as 4: 12 times
==========================Scores========================================
# of classes: 5
Accuracy: 0.9602
Precision: 0.9671
Recall: 0.9284
F1 Score: 0.9440
Precision, recall & F1: macro-averaged (equally weighted avg. of 5 classes)
========================================================================
感想
特徴量が数百を超えるときは、ロジスティック回帰やSVM(リニア)が用いられるが、LSTMを使うのも良いのかもしれないと思った。
学習の時間も思ったほどかからない。
やってみただけだけれど、試しといてよかった。
参考資料
- Java Deep Learning Projects: Implement 10 real-world deep learning applications using Deeplearning4j and open source APIs
添付(pom.xml)
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.vis</groupId>
<artifactId>CancerGenomeTest</artifactId>
<version>0.0.1-SNAPSHOT</version>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<java.version>1.8</java.version>
<nd4j.version>1.0.0-alpha</nd4j.version>
<dl4j.version>1.0.0-alpha</dl4j.version>
<datavec.version>1.0.0-alpha</datavec.version>
<arbiter.version>1.0.0-alpha</arbiter.version>
<logback.version>1.2.3</logback.version>
<dl4j.spark.version>1.0.0-alpha_spark_2</dl4j.spark.version>
</properties>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-spark_2.11</artifactId>
<version>${dl4j.spark.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nlp</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-zoo</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>arbiter-deeplearning4j</artifactId>
<version>${arbiter.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>arbiter-ui_2.11</artifactId>
<version>${arbiter.version}</version>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-data-codec</artifactId>
<version>${datavec.version}</version>
</dependency>
<dependency>
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpclient</artifactId>
<version>4.3.5</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>${logback.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
<version>2.11.0</version>
</dependency>
</dependencies>
</project>