天天看點

TorchVision 預訓練模型進行推斷

torchvision.models 裡包含了許多模型,用于解決不同的視覺任務:圖像分類、語義分割、物體檢測、執行個體分割、人體關鍵點檢測和視訊分類。

本文将介紹 torchvision 中模型的入門使用,一起來建立 Faster R-CNN 預訓練模型,預測圖像中有什麼物體吧。

<code>print(model)</code> 可檢視其結構:

此預訓練模型是于 COCO train2017 上訓練的,可預測的分類有:

擷取支援的 <code>device</code>:

模型移到 <code>device</code>:

準備模型入參 <code>images</code>:

例圖 <code>data/bicycle.jpg</code>:

TorchVision 預訓練模型進行推斷

模型切為 <code>eval</code> 模式:

模型在推斷時,隻需要給到圖像資料,不用标注資料。推斷後,會傳回每個圖像的預測結果 <code>List[Dict[Tensor]]</code>。<code>Dict</code> 包含字段有:

boxes (<code>FloatTensor[N, 4]</code>): 預測框 <code>[x1, y1, x2, y2]</code>, <code>x</code> 範圍 <code>[0,W]</code>, <code>y</code> 範圍 <code>[0,H]</code>

labels (<code>Int64Tensor[N]</code>): 預測類别

scores (<code>Tensor[N]</code>): 預測評分

預測結果如下:

擷取 <code>score &gt;= 0.9</code> 的預測結果:

引入 <code>utils.plots.plot_image</code> 繪制結果:

<code>utils.plots.plot_image</code> 函數實作可見後文源碼,注意其要求 <code>torchvision &gt;= 0.9.0/nightly</code>。

test_pretrained_models.py

<code>utils.colors.golden</code>:

<code>utils.plots.plot_image</code>:

torch.hub

torchvision.models

GoCoding 個人實踐的經驗分享,可關注公衆号!

繼續閱讀