How to Retrain an Image Classifier in TensorFlow
You have probably heard about image classification in machine learning. Training an image classifier from scratch requires a lot of training data and computing power, as well as a lot of knowledge and skill. Fortunately, Google has created a very useful open-source machine learning library called TensorFlow, which helps us get things done in machine learning much more easily and quickly than before. In this article I will walk you through the easiest way to train your own image classifier with TensorFlow: retraining an existing model on your own categories. It doesn't require advanced machine learning skills, but you should understand the basics of machine learning first.
Set up the environment
- I recommend using Docker with TensorFlow, because the official TensorFlow Docker images come with all the dependencies TensorFlow needs. This saves a lot of time and effort compared with going through the installation step by step. So what are you waiting for? Go ahead and download and install Docker on your machine.
- After you install Docker, open your terminal and type:
docker run hello-world
to check whether Docker is installed correctly on your machine.

If the output starts with "Hello from Docker!", Docker has been installed successfully.
- Now pull the TensorFlow Docker image with the following command (this requires an internet connection; on Linux you may need to prefix Docker commands with sudo):
docker pull tensorflow/tensorflow:1.1.0
- Check that the TensorFlow image was downloaded successfully:
docker images
- Now start a bash shell inside the TensorFlow container:
docker run -it tensorflow/tensorflow:1.1.0 bash
You are now ready to train the model.
Training the model
- Create a new folder inside the TensorFlow container and move into it:
mkdir tensorflow
cd tensorflow
- Download the sample flower photos dataset provided by Google and extract it:
curl -LO http://download.tensorflow.org/example_images/flower_photos.tgz
tar xzf flower_photos.tgz
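retrain.py treats every subfolder of --image_dir as one category and uses the folder name as the label. The extracted flower_photos folder follows that layout, with daisy, dandelion, roses, sunflowers and tulips. To train on your own photos, arrange them the same way. Below is a minimal Python 3 sketch for checking how many images each label has before you start training. The function name count_images_per_label is hypothetical. Counting only .jpg/.jpeg files is an assumption that fits this dataset.

```python
from pathlib import Path

# Extensions counted as images here; chosen to match the JPEG files in
# flower_photos (an assumption for other datasets).
IMAGE_EXTENSIONS = {".jpg", ".jpeg"}


def count_images_per_label(image_dir):
    """Return {label: number_of_images}, treating each subfolder name as a label."""
    counts = {}
    for sub in sorted(Path(image_dir).iterdir()):
        if not sub.is_dir():
            continue  # skip loose files such as LICENSE.txt
        counts[sub.name] = sum(
            1 for f in sub.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
    return counts
```

For example, print(count_images_per_label("flower_photos")) lists each flower type with its image count. A label with very few images is a hint to collect more before retraining.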
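- Download the retrain.py script from GitHub. This version comes from the tensorflow/hub repository, so it also needs the tensorflow_hub package (pip install tensorflow-hub), which requires TensorFlow 1.7 or later: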
curl -LO https://github.com/tensorflow/hub/raw/r0.1/examples/image_retraining/retrain.py
- Start training the model with the flower dataset you just downloaded:
python retrain.py \
--bottleneck_dir=tf_files/bottlenecks \
--how_many_training_steps=500 \
--summaries_dir=tf_files/training_summaries \
--output_graph=tf_files/retrained_graph.pb \
--output_labels=tf_files/retrained_labels.txt \
--image_dir=flower_photos
Once you run this command, training will take some time. It took me about 30 minutes on my computer. If it takes longer on yours, don't freak out; just do something useful in the meantime!
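The network itself is not trained from scratch. retrain.py keeps the pretrained network fixed and computes a "bottleneck" feature vector for each image, which it caches in --bottleneck_dir. It then trains only a new final softmax layer on those vectors, using gradient descent on cross-entropy loss. Below is a toy pure-Python sketch of that last idea, and it is not retrain.py's TensorFlow code. It makes several assumptions. Tiny 2-D vectors stand in for real bottlenecks. Every step uses the whole toy dataset, while retrain.py samples mini-batches. The names train_final_layer, logits_for and predict are hypothetical.

```python
import math
import random


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max so math.exp never overflows
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def logits_for(W, b, x):
    """Compute one raw score per class: W[c] . x + b[c]."""
    return [sum(w * xi for w, xi in zip(row, x)) + bc for row, bc in zip(W, b)]


def train_final_layer(features, labels, num_classes, steps=500, lr=0.5, seed=0):
    """Train softmax weights W and biases b with full-batch gradient descent on
    cross-entropy loss. The input feature vectors themselves never change."""
    rng = random.Random(seed)
    dim = len(features[0])
    W = [[rng.uniform(-0.01, 0.01) for _ in range(dim)] for _ in range(num_classes)]
    b = [0.0] * num_classes
    n = len(features)
    for _ in range(steps):
        grad_W = [[0.0] * dim for _ in range(num_classes)]
        grad_b = [0.0] * num_classes
        for x, y in zip(features, labels):
            probs = softmax(logits_for(W, b, x))
            for c in range(num_classes):
                # Gradient of cross-entropy w.r.t. logit c: p_c - 1[c == y]
                err = probs[c] - (1.0 if c == y else 0.0)
                grad_b[c] += err
                for j in range(dim):
                    grad_W[c][j] += err * x[j]
        # Average the gradients over the dataset and take one step downhill
        for c in range(num_classes):
            b[c] -= lr * grad_b[c] / n
            for j in range(dim):
                W[c][j] -= lr * grad_W[c][j] / n
    return W, b


def predict(W, b, x):
    """Return class probabilities for one feature vector."""
    return softmax(logits_for(W, b, x))
```

Only W and b are learned. That is why retraining is so much cheaper than training a whole network, and why the cached bottlenecks can be reused between runs.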
The training script writes the following files:
- tf_files/retrained_graph.pb: contains a version of the selected network with a final layer retrained on your categories.
- tf_files/retrained_labels.txt: a text file containing the labels, one per line.
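Line i of retrained_labels.txt names the category for the i-th score produced by the final_result layer. label_image.py relies on this to turn scores into flower names. Below is a minimal Python 3 sketch of reading that file. load_labels is a hypothetical helper name, and skipping blank lines is my own choice.

```python
def load_labels(path):
    """Read a labels file written by retrain.py: one label per line, in the
    same order as the model's output scores."""
    with open(path) as f:
        return [line.strip() for line in f if line.strip()]
```

For example, load_labels("tf_files/retrained_labels.txt") returns the flower category names as a list.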
- Once training has finished, you can start classifying images. Download the label_image.py example provided by Google from GitHub to test your trained model:
curl -LO https://github.com/tensorflow/tensorflow/raw/master/tensorflow/examples/label_image/label_image.py
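- To classify an image, run label_image.py with the files the training step produced. The image must be inside the container; you can copy one in from your computer with docker cp:
python label_image.py \
--graph=tf_files/retrained_graph.pb --labels=tf_files/retrained_labels.txt \
--input_layer=Placeholder \
--output_layer=final_result \
--image={PATH TO YOUR TEST IMAGE INSIDE THE CONTAINER}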
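To check several test images at once, you can call label_image.py in a loop. Below is a minimal Python 3.7+ sketch (subprocess.run with capture_output needs 3.7). The helper names label_image_cmd and classify_folder are hypothetical. The default paths assume the tf_files/retrained_graph.pb and tf_files/retrained_labels.txt outputs of the training step. Only flags used in the label_image.py command are passed.

```python
import subprocess
from pathlib import Path


def label_image_cmd(image, graph="tf_files/retrained_graph.pb",
                    labels="tf_files/retrained_labels.txt"):
    """Build the label_image.py command line for one image as a list of args."""
    return [
        "python", "label_image.py",
        "--graph=" + graph,
        "--labels=" + labels,
        "--input_layer=Placeholder",
        "--output_layer=final_result",
        "--image=" + str(image),
    ]


def classify_folder(folder, pattern="*.jpg"):
    """Run label_image.py on every matching image in a folder and print its output."""
    for image in sorted(Path(folder).glob(pattern)):
        result = subprocess.run(label_image_cmd(image), capture_output=True,
                                text=True, check=True)
        print(image.name)
        print(result.stdout)
```

For example, classify_folder("my_test_images") classifies every .jpg in that (hypothetical) folder. check=True stops at the first image the script fails on.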
After running the script, you should see the flower labels ranked by score, for example:
daisy (score = 0.99071)
sunflowers (score = 0.00595)
dandelion (score = 0.00252)
roses (score = 0.00049)
tulips (score = 0.00032)
Each score is the model's estimated probability that the image belongs to that category; here the model is about 99% sure the image is a daisy.
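The scores come from the softmax final_result layer. They are therefore non-negative and add up to 1, and the five values above sum to 0.99999 only because of rounding. If you use the model from your own code, you get one score per label and sort them yourself. Below is a minimal Python 3 sketch that does this and reproduces the output format shown above. The function names are hypothetical. The sample inputs in the usage note are taken from the output above.

```python
def rank_scores(labels, scores, k=None):
    """Pair each label with its score, highest score first; keep the top k if given."""
    if len(labels) != len(scores):
        raise ValueError("need exactly one score per label")
    ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)
    return ranked if k is None else ranked[:k]


def format_ranked(ranked):
    """Format results like the sample output: 'label (score = 0.12345)'."""
    return ["{} (score = {:.5f})".format(label, score) for label, score in ranked]
```

With labels ["daisy", "dandelion", "roses", "sunflowers", "tulips"] and scores [0.99071, 0.00252, 0.00049, 0.00595, 0.00032], rank_scores(labels, scores, k=3) keeps daisy, sunflowers and dandelion, in that order.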
You can now use your trained model anywhere you want, for example in Android or iOS apps.