AndroidでTensorFlow

  • 0
    いいね
  • 0
    コメント

    概要

    AndroidでTensorFlowを動かしてみた。
    API Level 17（Android 4.2.2）で試した。

    写真

    device-2017-05-13-114913.png

    成果物

    https://play.google.com/store/apps/details?id=com.ohisamallc.ohiapp142&hl=ja

    環境

    Windows 7
    ADT v22.6
    Android 4.2.2（API Level 17）

    ビルド手順

    GitHubから以下を入手する。

    graph_label_strings.txt
    mnist_model_graph.pb
    libtensorflow_inference.so
    TensorFlowのJavaソースコード（org.tensorflow パッケージ一式）

    プロジェクトへの配置

    graph_label_strings.txt
    mnist_model_graph.pb
    → assets/ に置く
    libtensorflow_inference.so
    → libs/<ABI名>/（例: libs/armeabi-v7a/）に置く
    TensorFlowのJavaソースコード
    → src/ に置く

    サンプルコード

    
    package com.ohisamallc.ohiapp142;
    
    import android.content.res.AssetManager;
    import android.util.Log;
    import org.tensorflow.Operation;
    import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
    import java.io.BufferedReader;
    import java.io.IOException;
    import java.io.InputStreamReader;
    import java.util.ArrayList;
    import java.util.Comparator;
    import java.util.List;
    import java.util.PriorityQueue;
    import java.util.Vector;
    
    /**
     * {@link Classifier} that runs a frozen TensorFlow graph through
     * {@link TensorFlowInferenceInterface} and returns the highest-scoring classes.
     *
     * <p>Instances are obtained via {@link #create}. The score buffer is reused across calls to
     * {@link #recognizeImage}, so one instance should not be used from several threads at once.
     */
    public class TensorFlowImageClassifier implements Classifier {
        private static final String TAG = "TensorFlowImageClassifier";
        /** Maximum number of recognitions returned by {@link #recognizeImage}. */
        private static final int MAX_RESULTS = 3;
        /** Only classes whose score is strictly greater than this are reported. */
        private static final float THRESHOLD = 0.1f;
        /** URI-style prefix callers may put in front of asset paths. */
        private static final String ASSET_PREFIX = "file:///android_asset/";

        private String inputName;
        private String outputName;
        private int inputSize;
        /** Line i of the label file names output class i. */
        private final List<String> labels = new ArrayList<String>();
        /** Reused buffer for the scores fetched from {@link #outputName}. */
        private float[] outputs;
        private String[] outputNames;
        /** Passed to {@code run(...)}; toggled by {@link #enableStatLogging}. */
        private boolean logStats = false;
        private TensorFlowInferenceInterface inferenceInterface;

        private TensorFlowImageClassifier() {
        }

        /**
         * Creates a classifier from a frozen graph and a label file stored in the APK assets.
         *
         * @param assetManager  used to open the label file and the model
         * @param modelFilename model path, passed unchanged to {@link TensorFlowInferenceInterface}
         * @param labelFilename label file, one label per line; a leading
         *                      {@code file:///android_asset/} prefix is stripped before it is
         *                      opened with {@link AssetManager#open}
         * @param inputSize     stored on the instance; not otherwise used by this class
         * @param inputName     name of the graph node that receives the pixel values
         * @param outputName    name of the graph operation whose scores are read back
         * @return a classifier whose class count is dimension 1 of the output's shape
         * @throws IOException if the label file cannot be read
         * @throws IllegalStateException if dimension 1 of the output shape is unknown
         */
        public static Classifier create(AssetManager assetManager, String modelFilename, String labelFilename, int inputSize, String inputName, String outputName)
                throws IOException {
            TensorFlowImageClassifier c = new TensorFlowImageClassifier();
            c.inputName = inputName;
            c.outputName = outputName;

            String actualFilename = stripAssetPrefix(labelFilename);
            Log.i(TAG, "Reading labels from: " + actualFilename);
            readLabels(assetManager, actualFilename, c.labels);

            c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
            final Operation operation = c.inferenceInterface.graphOperation(outputName);
            // Assumes the output is shaped [batch, numClasses]; size() reports -1 when unknown.
            final int numClasses = (int) operation.output(0).shape().size(1);
            if (numClasses < 0) {
                // Release the loaded graph before failing instead of leaking it.
                c.inferenceInterface.close();
                throw new IllegalStateException(
                        "Output '" + outputName + "' has an unknown class dimension");
            }
            Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses);
            c.inputSize = inputSize;
            c.outputNames = new String[] {
                outputName
            };
            c.outputs = new float[numClasses];
            return c;
        }

        /** Returns {@code filename} without a leading {@link #ASSET_PREFIX}, if it has one. */
        private static String stripAssetPrefix(String filename) {
            return filename.startsWith(ASSET_PREFIX)
                    ? filename.substring(ASSET_PREFIX.length())
                    : filename;
        }

        /** Appends every line of asset {@code assetPath}, decoded as UTF-8, to {@code out}. */
        private static void readLabels(AssetManager assetManager, String assetPath, List<String> out)
                throws IOException {
            BufferedReader br = new BufferedReader(
                    new InputStreamReader(assetManager.open(assetPath), "UTF-8"));
            // Plain try/finally: try-with-resources needs API 19, this project targets API 17.
            try {
                String line;
                while ((line = br.readLine()) != null) {
                    out.add(line);
                }
            } finally {
                br.close();
            }
        }

        /**
         * Runs the graph on {@code pixels} and returns up to {@link #MAX_RESULTS} classes whose
         * score exceeds {@link #THRESHOLD}, highest score first.
         *
         * @param pixels flat input values; fed as a tensor of shape {@code [pixels.length]}
         *               (784 for the 28x28 MNIST model this was written for)
         * @return recognitions ordered by descending confidence; possibly empty
         */
        @Override
        public List<Recognition> recognizeImage(final float[] pixels) {
            // Tensor shape must match the buffer length, so derive it instead of hard-coding 784.
            inferenceInterface.feed(inputName, pixels, pixels.length);
            inferenceInterface.run(outputNames, logStats);
            inferenceInterface.fetch(outputName, outputs);

            PriorityQueue<Recognition> pq = new PriorityQueue<Recognition>(MAX_RESULTS, new Comparator<Recognition>() {
                @Override
                public int compare(Recognition lhs, Recognition rhs) {
                    // Reverse order: highest confidence at the head of the queue.
                    return Float.compare(rhs.getConfidence(), lhs.getConfidence());
                }
            });
            for (int i = 0; i < outputs.length; ++i) {
                if (outputs[i] > THRESHOLD) {
                    String title = labels.size() > i ? labels.get(i) : "unknown";
                    pq.add(new Recognition("" + i, title, outputs[i], null));
                }
            }

            final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
            int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
            for (int i = 0; i < recognitionsSize; ++i) {
                recognitions.add(pq.poll());
            }
            return recognitions;
        }

        /**
         * Enables or disables TensorFlow run statistics for subsequent {@link #recognizeImage}
         * calls (the flag is forwarded to {@code run(outputNames, logStats)}).
         */
        @Override
        public void enableStatLogging(boolean debug) {
            logStats = debug;
        }

        /** Returns the statistics string reported by the underlying inference interface. */
        @Override
        public String getStatString() {
            return inferenceInterface.getStatString();
        }

        /** Releases the underlying inference interface. */
        @Override
        public void close() {
            inferenceInterface.close();
        }
    }