Tensorflow object detection的安裝請參考連結【Tensorflow】安裝tensorflow object detection API。
1. 下載下傳VOC資料集
到官網下載下傳VOC資料集。資料集的目錄結構如下:
2. 制作tfrecord
在models/research/object_detection/dataset_tools下有一個create_pascal_tf_record.py腳本,運作這個腳本可以直接将VOC資料集轉換成tfrecord格式的資料。
python create_pascal_tf_record.py --data_dir=/home/data/VOCdevkit --year=2012 --set=train --output_path=./data/pascal_train.record
python create_pascal_tf_record.py --data_dir=/home/data/VOCdevkit --year=2012 --set=val --output_path=./data/pascal_val.record
在data目錄下生成了兩個record檔案。
3. 下載下傳預訓練權重
下載下傳位址
4. 修改config檔案
在models/research/object_detection/samples/configs/目錄下将ssd_mobilenet_v2_coco.config複制一份重命名為ssd_mobilenet_v2_pascal.config,修改以下幾個地方:
第9行
num_classes: 90
修改為
num_classes: 20
156行
fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
替換成自己的路徑
fine_tune_checkpoint: "/home/models/ssd_mobilenet_v2_coco_2018_03_29/model.ckpt"
第173行
train_input_reader: {
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/mscoco_train.record-?????-of-00100"
}
label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
}
修改為
train_input_reader: {
tf_record_input_reader {
input_path: "/home/models/research/object_detection/data/pascal_train.record"
}
label_map_path: "home/models/reseach/object_detection/data/pascal_label_map.pbtxt"
}
修改182行num_examples你驗證集的圖像數量。
第187行
val_input_reader: {
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/mscoco_val.record-?????-of-00010"
}
label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
shuffle: false
num_readers:1
}
修改為
val_input_reader: {
tf_record_input_reader {
input_path: "/home/models/research/object_detection/data/pascal_train.record"
}
label_map_path: "home/models/reseach/object_detection/data/pascal_label_map.pbtxt"
shuffle: false
num_readers:1
}
5. 訓練
python object_detection/model_main.py --logtostderr --pipeline_config_path=/home/models/research/object_detection/samples/configs/ssd_mobilenet_v2_pascal.config --model_dir=/home/data/VOCdekit/ssd_mobilnet_v2 --num_train_steps=50000 --num_eval_steps=500
6. 訓練可視化
新版的代碼在訓練時terminal可能會卡住沒有輸出,不過沒關系,可以在tensorboard中檢視訓練情況。
tensorboard --logdir=/home/data/VOCdekit/ssd_mobilnet_v2
把終端輸出的http://xxxxxx複制到浏覽器中打開
7. 固化權重
python object_detection/export_inference_graph.py --input_type=image_tensor --pipeline_config_path=/home/models/research/object_detection/samples/configs/ssd_mobilenet_v2_pascal.config --trained_checkpoint_prefix=/home/data/VOCdekit/ssd_mobilnet_v2/model.ckpt-50000 --output_directory=/home/data/VOCdekit/ssd_mobilnet_v2_pascal
生成如下檔案