天天看點

【Tensorflow】object_detection:SSD_MobileNetV2訓練VOC資料集

Tensorflow object detection的安裝請參考連結【Tensorflow】安裝tensorflow object detection API。

1. 下載下傳VOC資料集

到官網下載下傳VOC資料集。資料集的目錄結構如下:

【Tensorflow】object_detection:SSD_MobileNetV2訓練VOC資料集

2. 制作tfrecord

在models/research/object_detection/dataset_tools下有一個create_pascal_tf_record.py腳本,運作這個腳本可以直接将VOC資料集轉換成tfrecord格式的資料。

python create_pascal_tf_record.py --data_dir=/home/data/VOCdevkit --year=2012 --set=train --output_path=./data/pascal_train.record
python create_pascal_tf_record.py --data_dir=/home/data/VOCdevkit --year=2012 --set=val --output_path=./data/pascal_val.record
           

在data目錄下生成了兩個record檔案。

【Tensorflow】object_detection:SSD_MobileNetV2訓練VOC資料集
【Tensorflow】object_detection:SSD_MobileNetV2訓練VOC資料集

3. 下載下傳預訓練權重

下載下傳位址

4. 修改config檔案

在models/research/object_detection/samples/configs/目錄下将ssd_mobilenet_v2_coco.config複制一份重命名為ssd_mobilenet_v2_pascal.config,修改以下幾個地方:

第9行

num_classes: 90
           

修改為

num_classes: 20
           

156行

fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
           

替換成自己的路徑

fine_tune_checkpoint: "/home/models/ssd_mobilenet_v2_coco_2018_03_29/model.ckpt"
           

第173行

train_input_reader: {
  tf_record_input_reader {
    input_path: "PATH_TO_BE_CONFIGURED/mscoco_train.record-?????-of-00100"
  }
  label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
}
           

修改為

train_input_reader: {
  tf_record_input_reader {
    input_path: "/home/models/research/object_detection/data/pascal_train.record"
  }
  label_map_path: "home/models/reseach/object_detection/data/pascal_label_map.pbtxt"
}
           

修改182行num_examples你驗證集的圖像數量。

第187行

​
val_input_reader: {
  tf_record_input_reader {
    input_path: "PATH_TO_BE_CONFIGURED/mscoco_val.record-?????-of-00010"
  }
  label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
  shuffle: false
  num_readers:1
}​
           

修改為

​
​
val_input_reader: {
  tf_record_input_reader {
    input_path: "/home/models/research/object_detection/data/pascal_train.record"
  }
  label_map_path: "home/models/reseach/object_detection/data/pascal_label_map.pbtxt"
  shuffle: false
  num_readers:1
}​

​
           

5. 訓練

python object_detection/model_main.py --logtostderr --pipeline_config_path=/home/models/research/object_detection/samples/configs/ssd_mobilenet_v2_pascal.config --model_dir=/home/data/VOCdekit/ssd_mobilnet_v2 --num_train_steps=50000 --num_eval_steps=500
           
【Tensorflow】object_detection:SSD_MobileNetV2訓練VOC資料集

6. 訓練可視化

新版的代碼在訓練時terminal可能會卡住沒有輸出,不過沒關系,可以在tensorboard中檢視訓練情況。

tensorboard --logdir=/home/data/VOCdekit/ssd_mobilnet_v2
           

把終端輸出的http://xxxxxx複制到浏覽器中打開

【Tensorflow】object_detection:SSD_MobileNetV2訓練VOC資料集

7. 固化權重

python object_detection/export_inference_graph.py --input_type=image_tensor --pipeline_config_path=/home/models/research/object_detection/samples/configs/ssd_mobilenet_v2_pascal.config --trained_checkpoint_prefix=/home/data/VOCdekit/ssd_mobilnet_v2/model.ckpt-50000 --output_directory=/home/data/VOCdekit/ssd_mobilnet_v2_pascal
           

生成如下檔案

【Tensorflow】object_detection:SSD_MobileNetV2訓練VOC資料集

繼續閱讀