天天看點

幾行代碼建構全功能的對象檢測模型,他是如何做到的?

以下文章來源于AI科技大學營 ,作者Alan Bi

幾行代碼建構全功能的對象檢測模型,他是如何做到的?

作者 | Alan Bi

譯者 | 武明利,責編 | Carol

出品 | AI科技大學營(ID:rgznai100)

如今,機器學習和計算機視覺已成為一種熱潮。我們都看過關于自動駕駛汽車和面部識别的新聞,可能會想象建立自己的計算機視覺模型有多酷。然而,進入這個領域并不總是那麼容易,尤其是在沒有很強的數學背景的情況下。如果你隻想做一些小的實驗,像PyTorch和TensorFlow這樣的庫可能會很枯燥。

在本教程中,作者提供了一種簡單的方法,任何人都可以使用幾行代碼建構全功能的對象檢測模型。更具體地說,我們将使用Detecto,這是一個在PyTorch之上建構的Python軟體包,可簡化該過程并向所有級别的程式員開放。

幾行代碼建構全功能的對象檢測模型,他是如何做到的?

快速簡單的例子

為了示範如何簡單地使Detecto,讓我們加載一個預先訓練的模型,并對以下圖像進行推斷:

幾行代碼建構全功能的對象檢測模型,他是如何做到的?

首先,使用pip下載下傳Detecto軟體包:

pip3 install detecto

然後,将上面的圖像另存為“fruit.jpg”,并在與圖像相同的檔案夾中建立一個Python檔案。在Python檔案中,編寫以下5行代碼:

from detectoimport core, utils, visualize           

複制

複制

image = utils.read_image('fruit.jpg')           

複制

model = core.Model()           

複制

複制

labels, boxes, scores = model.predict_top(image)           

複制

visualize.show_labeled_image(image, boxes, labels)           

複制

運作此檔案後(如果你的計算機上沒有啟用CUDA的GPU,可能會花費幾秒鐘;稍後再進行介紹),你應該會看到類似下面的圖:

幾行代碼建構全功能的對象檢測模型,他是如何做到的?

作者僅用了5行代碼就完成了所有工作,真的是太棒了。下面是我們每步中分别做的:

1)導入Detecto子產品

2)讀入圖像

3)初始化預訓練模型

4)在圖像上生成最高預測

5)為預測繪圖

幾行代碼建構全功能的對象檢測模型,他是如何做到的?

繪制我們的預測

Detecto使用來自PyTorch模型動物園中的Faster R-CNN ResNet-50 FPN,它能夠檢測大約80種不同的物體,例如動物,車輛,廚房用具等。但是,如果你想要檢測自定義對象,例如可口可樂與百事可樂罐,斑馬與長頸鹿,該怎麼辦呢?

這時你會發現,在自定義資料集上訓練探測器模型同樣簡單; 同樣,你隻需要5行代碼,以及現有的資料集或花一些時間标記圖像。

幾行代碼建構全功能的對象檢測模型,他是如何做到的?

建構自定義資料集

在本教程中,作者将從頭開始建構自己的資料集。建議你也這樣做,但是如果你想跳過這一步,你可以在這裡下載下傳一個示例資料集(從斯坦福的Dog資料集修改)。

對于我們的資料集,我們将訓練我們的模型來檢測來自RoboSub競賽的水下外星人,蝙蝠和女巫,如下所示:

幾行代碼建構全功能的對象檢測模型,他是如何做到的?

理想情況下,每個類至少需要100張圖像。好在每張圖像中可以有多個對象,是以理論上,如果每張圖像包含你想要檢測的每類對象,那麼你可以總共獲得100張圖像。另外,如果你有視訊素材,Detico可以輕松地将這些視訊素材分割成可用于資料集的圖像:

from detecto.utilsimport split_video           

複制

複制

split_video('video.mp4','frames/', step_size=4)           

複制

上面的代碼在“video.mp4”中每第4幀拍攝一次,并将其另存為JPEG檔案存在“frames”檔案夾中。

生成訓練資料集後,應該具有一個類似于以下内容的檔案夾:

images/           

複制

|   image0.jpg           

複制

|   image1.jpg           

複制

|   image2.jpg           

複制

|   ...           

複制

如果需要的話,你還可以使用另一個檔案夾,其中包含一組驗證圖像。

現在是耗時的部分:标記。Detecto支援PASCAL VOC格式,其中具有XML檔案,其中包含圖像中每個對象的标簽和位置資料。要建立這些XML檔案,可以使用開源LabelImg工具,如下所示:

pip3 install labelImg   # Download LabelImg using pip           

複制

labelImg                # Launch the application           

複制

現在,你應該會看到一個彈出視窗。單擊左側“打開目錄”按鈕,然後選擇想要标記的圖像檔案夾。如果一切正常,你應該會看到類似以下内容:

幾行代碼建構全功能的對象檢測模型,他是如何做到的?

要繪制邊界框,請單擊左側菜單欄中的圖示(或使用鍵盤快捷鍵“w”)。然後,你可以在對象周圍拖動一個框并編寫/選擇标簽:

幾行代碼建構全功能的對象檢測模型,他是如何做到的?

标記完圖像後,請使用CTRL+S或CMD+S儲存XML檔案(為簡便起見,你可以使用自動填充的預設檔案位置和名稱)。要标記下一張圖像,請單擊“下一張圖像”(或使用鍵盤快捷鍵“d”)。

整個資料集處理完畢之後,你的檔案夾應如下所示:

images/           

複制

|   image0.jpg           

複制

|   image0.xml           

複制

|   image1.jpg           

複制

|   image1.xml           

複制

|   ...           

複制

我們已經準備好開始訓練我們的對象檢測模型了!

幾行代碼建構全功能的對象檢測模型,他是如何做到的?

通路GPU

首先,檢查你的計算機是否具有啟用CUDA的GPU。由于深度學習需要大量處理能力,是以在通常的CPU上進行訓練可能會非常緩慢。值得慶幸的是,大多數現代深度學習架構(例如PyTorch和Tensorflow)都可以在GPU上運作,進而使處理速度更快。確定已經下載下傳了PyTorch(如果你安裝了Detecto,應該已經下載下傳了),然後運作以下兩行代碼:

import torch           

複制

複制

print(torch.cuda.is_available())           

複制

如果列印True,那你可以跳到下一部分。如果顯示False,不要擔心。請按照以下步驟建立Google Colaboratory筆記本,這是一個線上編碼環境,帶有免費可用的GPU。對于本教程,你将隻在Google Drive檔案夾中工作,而不是在計算機上工作。

1)登入到Google Drive

2)建立一個名為“Detecto Tutorial”的檔案夾并導航到該檔案夾

3)将你的訓練圖像(和/或驗證圖像)上傳到此檔案夾

4)右鍵單擊,轉到“更多”,然後單擊“Google Colaboratory”:

幾行代碼建構全功能的對象檢測模型,他是如何做到的?

你現在應該看到這樣的界面:

幾行代碼建構全功能的對象檢測模型,他是如何做到的?

5)根據需要給筆記本起個名字,然後轉到“編輯”->“筆記本設定”->“硬體加速器”,然後選擇“GPU”

6)輸入以下代碼以“裝入”你的雲端硬碟,将目錄更改為目前檔案夾,然後安裝Detecto:

import os           

複制

from google.colabimport drive           

複制

drive.mount('/content/drive')           

複制

os.chdir('/content/drive/My Drive/Detecto Tutorial')           

複制

!pip install detecto           

複制

為了確定一切正常,你可以建立一個新的代碼單元,然後輸入!ls以檢查你是否處于正确的目錄中。

幾行代碼建構全功能的對象檢測模型,他是如何做到的?

訓練自定義模型

最後,我們現在可以在自定義資料集上訓練模型了。如前所述,這是容易的部分。它隻需要4行代碼:

from detectoimport core, utils, visualize           

複制

dataset = core.Dataset('images/')           

複制

model = core.Model(['alien','bat','witch'])           

複制

model.fit(dataset)           

複制

讓我們再次分解一下我們每行代碼所做的工作:

1、導入的Detecto子產品

2、從“images”檔案夾(包含我們的JPEG和XML檔案)建立了一個資料集

3、初始化模型檢測自定義對象(外星人,蝙蝠和女巫)

4、在資料集上訓練我們的模型

根據資料集的大小,這可能需要10分鐘到1個小時以上的時間來運作,是以請確定你的程式在完成上述語句後不會立即退出(例如:你使用的是Jupyter / Colab筆記本,它在活動時保留狀态)。

幾行代碼建構全功能的對象檢測模型,他是如何做到的?

使用訓練好的模型

現在你已經有了訓練好的模型,讓我們在一些圖像上對其進行測試。要從檔案路徑讀取圖像,可以使用detecto.utils子產品中的read_image函數(也可以使用上面建立的資料集中的圖像):

# Specify the path to your image           

複制

image = utils.read_image('images/image0.jpg')           

複制

predictions = model.predict(image)           

複制

# predictions format: (labels, boxes, scores)           

複制

labels, boxes, scores = predictions           

複制

複制

# ['alien', 'bat', 'bat']           

複制

print(labels)           

複制

複制

#           xmin       ymin       xmax       ymax           

複制

# tensor([[ 569.2125,  203.6702, 1003.4383,  658.1044],           

複制

#         [ 276.2478,  144.0074,  579.6044,  508.7444],           

複制

#         [ 277.2929,  162.6719,  627.9399,  511.9841]])           

複制

print(boxes)           

複制

# tensor([0.9952, 0.9837, 0.5153])           

複制

print(scores)           

複制

正像你看到的,模型的預測方法傳回一個由3個元素組成的元組:标簽,方框和分數。在上面的示例中,此模型在坐标[569、204、1003、658](框[0])處預測了一個外星人(标簽[0]),其置信度為0.995(得分[0])。

根據這些預測,我們可以使用detecto.visualize子產品繪制結果。例如:

visualize.show_labeled_image(image, boxes, labels)           

複制

将上面的代碼與收到的圖像和預測一起運作将産生如下所示的内容:

幾行代碼建構全功能的對象檢測模型,他是如何做到的?

如果你有一個視訊,你可以在它上面運作對象檢測:

visualize.detect_video(model,'input.mp4','output.avi')           

複制

這将擷取一個名為“input.mp4”的視訊檔案,并根據給定模型的預測結果生成一個“output.avi”檔案。如果你使用VLC或其他視訊播放器打開此檔案,應該會看到一些希望看到的結果!

最後,你可以從檔案中儲存和加載模型,進而可以儲存進度并稍後傳回:

model.save('model_weights.pth')           

複制

複制

# ... Later ...           

複制

複制

model = core.Model.load('model_weights.pth', ['alien','bat','witch'])           

複制

幾行代碼建構全功能的對象檢測模型,他是如何做到的?

進階用法

你會發現Detecto不僅限于5行代碼。舉例來說,這個模型沒有你希望的那麼好。我們可以嘗試通過使用Torchvision轉換來擴充我們的資料集并定義一個自定義資料加載器來提高其性能:

from torchvisionimport transforms           

複制

augmentations = transforms.Compose([           

複制

transforms.ToPILImage(),           

複制

transforms.RandomHorizontalFlip(0.5),           

複制

transforms.ColorJitter(saturation=0.5),           

複制

transforms.ToTensor(),           

複制

utils.normalize_transform(),           

複制

])           

複制

複制

dataset = core.Dataset('images/', transform=augmentations)           

複制

複制

loader = core.DataLoader(dataset, batch_size=2, shuffle=True)           

複制

此代碼對資料集中的圖像應用了随機的水準翻轉和飽和效果,進而增加了資料的多樣性。然後,我們使用batch_size = 2定義一個資料加載對象;我們将其傳遞給model.fit而不是Dataset,這樣來告訴我們的模型是對2張圖像進行批量訓練,而不是預設的1張。

如果你之前建立了單獨的驗證資料集,那麼現在是在訓練期間加載它的時候了。通過提供驗證資料集,fit方法将傳回每個時期的損失清單,如果verbose = True,則會在訓練過程中将其列印出來。以下代碼塊示範了這一點,并自定義了其他幾個訓練參數:

import matplotlib.pyplotas plt           

複制

複制

val_dataset = core.Dataset('validation_images/')           

複制

複制

losses = model.fit(loader, val_dataset, epochs=10, learning_rate=0.001,           

複制

lr_step_size=5, verbose=True)           

複制

plt.plot(losses)           

複制

plt.show()           

複制

損失的結果圖應或多或少地減少:

幾行代碼建構全功能的對象檢測模型,他是如何做到的?

為了更具有靈活性和對模型的控制,你可以完全繞過Detecto。你可以根據需要随意調整model.get_internal_model方法傳回使用的基礎模型。

幾行代碼建構全功能的對象檢測模型,他是如何做到的?

結論

在本教程中,作者展示了打造計算機視覺和對象檢測并沒有多大的挑戰性。你所需要的是一點時間和耐心來處理标記的的數集。

如果你對進一步探索感興趣的話,請檢視Detecto on GitHub或通路文檔以擷取更多教程和用例!

原文:

https://hackernoon.com/build-a-custom-trained-object-detection-model-with-5-lines-of-code-y08n33vi