0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

環境

  • python 3.7.7
  • CUDA 10.0.130
  • gcc 7.3.0

環境構築

git clone https://github.com/tensorflow/tpu.git
sudo apt-get install -y python-tk
pip install tensorflow-gpu==1.15
pip install --user Cython matplotlib opencv-python-headless pyyaml Pillow
pip install 'git+https://github.com/cocodataset/cocoapi#egg=pycocotools&subdirectory=PythonAPI'

学習済みモデルで推論

  1. 任意のモデルをダウンロード
    https://github.com/tensorflow/tpu/blob/master/models/official/detection/MODEL_ZOO.md

  2. 推論を実行

  • coco_label_map.csvの例
1:person
2:bicycle
3:car

category_id : category の形になっている
クラスを変更したい場合、上記の形式にそってcsvファイルを作成する

  • 実行コード

    python ~/tpu/models/official/detection/inference.py \
      --model="retinanet" \
      --image_size=640\
      --checkpoint_path="./detection_retinanet_50/model.  ckpt" \
      --label_map_file="./retinanet/tpu/models/official/  detection/datasets/coco_label_map.csv" \
      --image_file_pattern="path/to/input/file" \
      --output_html="path/to/output/file" \
      --max_boxes_to_draw=10 \
      --min_score_threshold=0.05
    
  • html形式で出力される。

  • 公式gitでは画像での推論のみ行うことができる。

3. オリジナルデータで学習

  1. 学習済みデータをダウンロード
    https://github.com/tensorflow/tpu/blob/master/models/official/detection/MODEL_ZOO.md
  2. 入力データを作成
  • 入力はtfrecorfd形式を想定。

  • coco形式でデータセットを準備(画像 + * .json)

  • 以下のプログラムでtfrecord形式に変換

    #!/bin/bash
    


TRAIN_IMAGE_DIR="path/to/train/images/dir"
TRAIN_OBJ_ANNOTATIONS_FILE="path/to/train/file"
OUTPUT_DIR="path/to/output/dir"
VAL_IMAGE_DIR="path/to/test/images/dir"
VAL_OBJ_ANNOTATIONS_FILE="path/to/test/images/dir"

function create_train_dataset(){
python3 create_coco_tf_record.py
--logtostderr
--include_masks
--image_dir="${TRAIN_IMAGE_DIR}"
--object_annotations_file="$ {TRAIN_OBJ_ANNOTATIONS_FILE}"
--output_file_prefix="${OUTPUT_DIR}/train"
--num_shards=256
}
function create_val_dataset() {
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
PYTHONPATH="tf-models:tf-models/research"
python3 $SCRIPT_DIR/create_coco_tf_record.py
--logtostderr
--include_masks
--image_dir="${VAL_IMAGE_DIR}"
--object_annotations_file="$ {VAL_OBJ_ANNOTATIONS_FILE}"
--output_file_prefix="${OUTPUT_DIR}/val"
--num_shards=32
}

create_train_dataset
create_val_dataset

​
3. 学習を実行

```shell
MODEL_DIR="<path to the directory to store model files>"
TRAIN_FILE_PATTERN="<path to the TFRecord training data>"
EVAL_FILE_PATTERN="<path to the TFRecord validation data>"
VAL_JSON_FILE="<path to the validation annotation JSON file>"
RESNET_CHECKPOINT="<path to trained model>"
python ~/tpu/models/official/detection/main.py \
  --model="retinanet" \
  --model_dir="${MODEL_DIR?}" \
  --mode=train \
  --eval_after_training=True \
  --use_tpu=False \
  --params_override="{train: { checkpoint: { path: ${RESNET_CHECKPOINT?}, prefix: resnet50/ }, train_file_pattern: ${TRAIN_FILE_PATTERN?} }, eval: { val_json_file: ${VAL_JSON_FILE?}, eval_file_pattern: ${EVAL_FILE_PATTERN?} }}"
  • 学習実行時のログで実行時間が把握できる
INFO:tensorflow:examples/sec: 0.622754
INFO:tensorflow:global_step/sec: 0.078258

4.推論

  1. 実行
python ~/tpu/models/official/detection/inference.py \
    --model="retinanet" \
    --image_size=640\
    --checkpoint_path="path/to/input" \
    --label_map_file="path/to/label" \
    --image_file_pattern="path/to/input/file" \
    --output_html="path/to/output/file" \
    --max_boxes_to_draw=10 \
    --min_score_threshold=0.05
  • オリジナルデータでの学習前のモデル
    (推論結果をアップ予定)
  • 学習後のモデル
    (推論結果をアップ予定)

5. 評価

  1. 実行
  • config_fileは学習の際に出力されるparams.yamlを使用

    python ${RETINA_ROOT}/evaluate_model.py\
      --model="retinanet"\
      --checkpoint_path="path/to/imput/file"\
      --config_file="${CONFIG_PATH}"\
      --params_override="${PARAMS_PATH}"\ 
      --dump_predictions_only = True\
      --predictions_path="path/to/output/file"
    


おまけ: 性能比較

せっかくなので過去に構築した他のモデルと比較してみた。
いずれもバッチサイズ8で、オリジナルのデータセットで学習・評価した。

学習速度

100イテレーションに費やした時間の平均を比較する。

時間
retinanet 約21[min]
ttfnet 約228[min]

精度

2000イテレーションの時のAPと推論結果から精度を比較する。

mAP
retinanet 96.35
ttfnet 79.78
0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?