本文将利用 TorchVision Faster R-CNN 預訓練模型,于 Kaggle: 全球小麥檢測 🌾 上實踐遷移學習中的一種常用技術:微調(fine tuning)。
本文相關的 Kaggle Notebooks 可見:
TorchVision Faster R-CNN Finetuning
TorchVision Faster R-CNN Inference
如果你沒有 GPU ,也可于 Kaggle 上線上訓練。使用介紹:
Use Kaggle Notebooks
那麼,我們開始吧 💪
Kaggle: 全球小麥檢測 <code>Data</code> 頁下載下傳資料,内容如下:
train.csv - the training data
sample_submission.csv - a sample submission file in the correct format
train.zip - training images
test.zip - test images
讀取 <code>train.csv</code> 内容:

image_id - the unique image ID
width, height - the width and height of the images
bbox - a bounding box, formatted as a Python-style list of [xmin, ymin, width, height]
etc.
把 <code>bbox</code> 替換成 <code>x</code> <code>y</code> <code>w</code> <code>h</code>:
訓練資料大小:
(147793, 8)
唯一 <code>image_id</code> 數量:
3373
<code>train</code> 目錄下圖檔數量:
3423
說明有 <code>3422-3373=49</code> 張圖檔沒有标注。
訓練資料,圖檔大小:
(array([1024]), array([1024]))
都是 <code>1024x1024</code> 的。
檢視标注數量的分布情況:
number of boxes, range [1, 116]
一張圖最多的有 <code>116</code> 個标注。
檢視标注坐标和寬高的分布情況:
把資料集分為訓練集和驗證集,比例 <code>8:2</code>:
((122577, 10), (25216, 10))
定義下輔助函數:
預覽圖檔,不加标注:
預覽圖檔,加上标注:
繼承 <code>torch.utils.data.Dataset</code> 抽象類,實作 <code>__len__</code> <code>__getitem__</code> 。且 <code>__getitem__</code> 傳回資料為:
image: a <code>numpy.ndarray</code> image
target: a dict containing the following fields
<code>boxes</code> (<code>FloatTensor[N, 4]</code>): the coordinates of the <code>N</code> bounding boxes in <code>[x0, y0, x1, y1]</code> format, ranging from <code>0</code> to <code>W</code> and <code>0</code> to <code>H</code>
<code>labels</code> (<code>Int64Tensor[N]</code>): the label for each bounding box
<code>image_id</code> (<code>Int64Tensor[1]</code>): an image identifier. It should be unique between all the images in the dataset, and is used during evaluation
<code>area</code> (<code>Tensor[N]</code>): The area of the bounding box. This is used during evaluation with the COCO metric, to separate the metric scores between small, medium and large boxes.
<code>iscrowd</code> (<code>UInt8Tensor[N]</code>): instances with <code>iscrowd=True</code> will be ignored during evaluation.
albumentations 是一個優秀的圖像增強的庫,用它定義了 <code>train</code> <code>valid</code> 的轉換方法。
現在建立 <code>train</code> <code>valid</code> 資料集:
預覽下資料集裡的圖檔:
建立一個 Faster R-CNN 預訓練模型:
輸出模型最後一層:
替換該層,指明輸出特征大小為 <code>2</code>:
再次輸出模型最後一層:
這裡我們從頭準備資料,再載入模型,進行預測。
用于送出結果的檔案。一行内容,表示一個圖檔的預測結果。如下:
<code>ce4833752,0.5 0 0 100 100</code>
<code>image_id</code> <code>ce4833752</code> 的圖檔,預測出 <code>x y w h</code> <code>0 0 100 100</code> 處是小麥,置信度 <code>0.5</code>。如果有多個預測框,能空格分隔。
執行個體化測試資料集:
{'image_id': 'aac893a91', 'PredictionString': '0.9928 72 2 96 166 0.9925 553 528 123 203 0.9912 613 921 85 102 0.9862 691 392 125 193 0.9855 819 708 105 204 0.9842 356 531 100 88 0.982 586 781 100 119 0.9795 739 768 82 122 0.9779 324 662 126 160 0.9764 27 454 102 156 0.9763 545 76 145 182 0.9736 450 858 90 95 0.9626 241 91 137 146 0.9406 306 0 75 68 0.9404 89 618 128 80 0.9366 177 576 114 182 0.9363 234 845 144 91 0.9265 64 857 115 69 0.824 822 630 90 124 0.7516 815 921 134 100'}
這就是 baseline 了,可以試着繼續調優 😊
TorchVision Instance Segmentation Finetuning Tutorial
Kaggle: Global Wheat Detection
Pytorch Starter - FasterRCNN Train
Global Wheat Detection: Starter EDA
GoCoding 個人實踐的經驗分享,可關注公衆号!