垃圾分類資料集+垃圾分類識别訓練代碼(Pytorch)
目錄
垃圾分類資料集+垃圾分類識别訓練代碼(Pytorch)
1. 前言
2. 垃圾資料集說明
(1)垃圾資料集dataset1
(2)垃圾資料集dataset2
3. 垃圾分類識别模型訓練
(1)項目架構說明
(2)準備Train和Test資料
(3)配置檔案:config.yaml
(4)開始訓練
(5)可視化訓練過程
(6)一些優化建議
4. 垃圾分類識别模型測試效果
5.項目源碼下載下傳
1. 前言
垃圾分類,指按一定規定或标準将垃圾分類儲存、分類投放和分類搬運,進而轉變成公共資源的一系列活動的總稱。分類的目的是提高垃圾的資源價值和經濟價值,力争物盡其用。智能化垃圾分類系統能能夠加速綠色環保的垃圾處理過程,并且對于居民垃圾分類意識的養成有極大的促進作用,對綠色都市和智能化城市管理都有着重大意義。
本項目将采用深度學習的方法,搭建一個垃圾分類識别的訓練和測試系統,實作智能化垃圾分類。目前,基于ResNet18的垃圾分類識别,在垃圾資料集dataset2,訓練集的Accuracy在94%左右,測試集的Accuracy在92%左右,如果想進一步提高準确率,可以嘗試:
- 最重要的: 清洗資料集,垃圾資料集dataset1和垃圾資料集dataset2,大部分資料都是網上爬取的,品質并不高,存在很多錯誤的圖檔,盡管鄙人已經清洗一部分了,但還是建議你,訓練前,再次清洗資料集,不然會影響模型的識别的準确率。
- 使用不同backbone模型,比如resnet50或者更深模型
- 增加資料增強: 已經支援: 随機裁剪,随機翻轉,随機旋轉,顔色變換等資料增強方式,可以嘗試諸如mixup,CutMix等更複雜的資料增強方式
- 樣本均衡: 目前訓練代碼已經支援樣本重采樣,設定resample=True即可
- 調超參: 比如學習率調整政策,優化器(SGD,Adam等)
- 損失函數: 目前訓練代碼已經支援:交叉熵,LabelSmoothing,可以嘗試FocalLoss等損失函數
【源碼下載下傳】垃圾分類資料集+垃圾分類識别訓練代碼(Pytorch)
2. 垃圾資料集說明
(1)垃圾資料集dataset1
這是通過網上爬蟲擷取的垃圾資料集,總共包含了可回收物(recyclables)、有害垃圾(hazardous)、廚餘垃圾(kitchen)、其他垃圾(other)四大類,以及40個小類。其中可回收物(recyclables)23種、有害垃圾(hazardous)3種、廚餘垃圾(kitchen)8種、其他垃圾(other)6種,每種垃圾包含大約400張圖檔,共1.7萬餘張圖像。
其中Train集16200張圖檔,平均每個類别405張;Test集800張圖檔,平均每個類别20張圖檔
垃圾類别 | 樣圖 |
可回收物(recyclables) | |
有害垃圾(hazardous) | |
廚餘垃圾(kitchen) | |
其他垃圾(other) |
下表給出垃圾資料集dataset1的40個類别:
0-other garbage-fast food box
1-other garbage-soiled plastic
2-other garbage-cigarette
3-other garbage-toothpick
4-other garbage-flowerpot
5-other garbage-bamboo chopsticks
6-kitchen waste-meal
7-kitchen waste-bone
8-kitchen waste-fruit peel
9-kitchen waste-pulp
10-kitchen waste-tea
11-kitchen waste-Vegetable
12-kitchen waste-eggshell
13-kitchen waste-fish bone
14-recyclables-powerbank
15-recyclables-bag
16-recyclables-cosmetic bottles
17-recyclables-toys
18-recyclables-plastic bowl
19-recyclables-plastic hanger
20-recyclables-paper bags
21-recyclables-plug wire
22-recyclables-old clothes
23-recyclables-can
24-recyclables-pillow
25-recyclables-plush toys
26-recyclables-shampoo bottle
27-recyclables-glass cup
28-recyclables-shoes
29-recyclables-anvil
30-recyclables-cardboard
31-recyclables-seasoning bottle
32-recyclables-bottle
33-recyclables-metal food cans
34-recyclables-pot
35-recyclables-edible oil barrel
36-recyclables-drink bottle
37-hazardous waste-dry battery
38-hazardous waste-ointment
39-hazardous waste-expired drugs
(2)垃圾資料集dataset2
該垃圾資料集是隻包含兩個大類,沒有細分小類,其中類别Organic表示有機垃圾,類别Recycle表示可回收垃圾。Train集22,566張圖檔,Test集2,513張圖檔。
垃圾類别 | 樣圖 |
Organic有機垃圾 | |
Recycle回收垃圾 |
垃圾資料集dataset1和垃圾資料集dataset2,大部分資料都是網上爬取的,品質并不高,存在很多錯誤的圖檔,盡管鄙人已經清洗一部分了,但還是建議你,訓練前,再次清洗資料集,不然會影響模型的識别的準确率。
3. 垃圾分類識别模型訓練
(1)項目架構說明
整套工程基本架構結構如下:
.
├── classifier # 訓練模型相關工具
├── configs # 訓練配置檔案
├── data # 訓練資料
├── libs
├── demo.py # 模型推理demo
├── README.md # 項目工程說明文檔
├── requirements.txt # 項目相關依賴包
└── train.py # 訓練檔案
(2)準備Train和Test資料
下載下傳垃圾分類資料集,Train和Test資料集,要求相同類别的圖檔,放在同一個檔案夾下;且子目錄檔案夾命名為類别名稱。
資料增強方式主要采用: 随機裁剪,随機翻轉,随機旋轉,顔色變換等處理方式
import numbers
import random
import PIL.Image as Image
import numpy as np
from torchvision import transforms
def image_transform(input_size, rgb_mean=[0.5, 0.5, 0.5], rgb_std=[0.5, 0.5, 0.5], trans_type="train"):
"""
不推薦使用:RandomResizedCrop(input_size), # bug:目标容易被crop掉
:param input_size: [w,h]
:param rgb_mean:
:param rgb_std:
:param trans_type:
:return::
"""
if trans_type == "train":
transform = transforms.Compose([
transforms.Resize([int(128 * input_size[1] / 112), int(128 * input_size[0] / 112)]),
transforms.RandomHorizontalFlip(), # 随機左右翻轉
# transforms.RandomVerticalFlip(), # 随機上下翻轉
transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),
transforms.RandomRotation(degrees=5),
transforms.RandomCrop([input_size[1], input_size[0]]),
transforms.ToTensor(),
transforms.Normalize(mean=rgb_mean, std=rgb_std),
])
elif trans_type == "val" or trans_type == "test":
transform = transforms.Compose([
transforms.Resize([input_size[1], input_size[0]]),
# transforms.CenterCrop([input_size[1], input_size[0]]),
# transforms.Resize(input_size),
transforms.ToTensor(),
transforms.Normalize(mean=rgb_mean, std=rgb_std),
])
else:
raise Exception("transform_type ERROR:{}".format(trans_type))
return transform
修改配置檔案資料路徑:config.yaml
# 訓練資料集,可支援多個資料集
train_data:
- '/path/to/rubbish/dataset2/train'
# 測試資料集
test_data: '/path/to/rubbish/dataset2/test'
# 類别檔案
class_name: '/path/to/rubbish/dataset2/class_name.txt'
(3)配置檔案:config.yaml
- 目前支援的backbone有:googlenet,resnet[18,34,50], ,mobilenet_v2等, 其他backbone可以自定義添加
- 訓練參數可以通過(configs/config.yaml)配置檔案進行設定
配置檔案:config.yaml說明如下:
# 訓練資料集,可支援多個資料集
train_data:
- '/path/to/rubbish/dataset2/train'
# 測試資料集
test_data: '/path/to/rubbish/dataset2/test'
# 類别檔案
class_name: '/path/to/rubbish/dataset2/class_name.txt'
train_transform: "train" # 訓練使用的資料增強方法
test_transform: "val" # 測試使用的資料增強方法
resample: True # 進行樣本均衡
work_dir: "work_space/" # 儲存輸出模型的目錄
net_type: "resnet18" # 骨幹網絡,支援:resnet18,resnet50,mobilenet_v2,googlenet
width_mult: 1.0
input_size: [ 224,224 ] # 模型輸入大小
rgb_mean: [ 0.5, 0.5, 0.5 ] # for normalize inputs to [-1, 1],Sequence of means for each channel.
rgb_std: [ 0.5, 0.5, 0.5 ] # for normalize,Sequence of standard deviations for each channel.
batch_size: 32
lr: 0.01 # 初始學習率
optim_type: "SGD" # 選擇優化器,SGD,Adam
loss_type: "CrossEntropyLoss" # 選擇損失函數:支援CrossEntropyLoss,LabelSmoothing
momentum: 0.9 # SGD momentum
num_epochs: 100 # 訓練循環次數
num_warn_up: 3 # warn-up次數
num_workers: 8 # 加載資料工作程序數
weight_decay: 0.0005 # weight_decay,預設5e-4
scheduler: "multi-step" # 學習率調整政策
milestones: [ 20,50,80 ] # 下調學習率方式
gpu_id: [ 0 ] # GPU ID
log_freq: 50 # LOG列印頻率
progress: True # 是否顯示進度條
pretrained: False # 是否使用pretrained模型
finetune: False # 是否進行finetune
參數 | 類型 | 參考值 | 說明 |
train_data | str, list | - | 訓練資料檔案,可支援多個檔案 |
test_data | str, list | - | 測試資料檔案,可支援多個檔案 |
class_name | str | - | 類别檔案 |
work_dir | str | work_space | 訓練輸出工作空間 |
net_type | str | resnet18 | backbone類型,{resnet18,resnet50,mobilenet_v2,...} |
input_size | list | [128,128] | 模型輸入大小[W,H] |
batch_size | int | 32 | batch size |
lr | float | 0.1 | 初始學習率大小 |
optim_type | str | SGD | 優化器,{SGD,Adam} |
loss_type | str | CELoss | 損失函數 |
scheduler | str | multi-step | 學習率調整政策,{multi-step,cosine} |
milestones | list | [30,80,100] | 降低學習率的節點,僅僅scheduler=multi-step有效 |
momentum | float | 0.9 | SGD動量因子 |
num_epochs | int | 120 | 循環訓練的次數 |
num_warn_up | int | 3 | warn_up的次數 |
num_workers | int | 12 | DataLoader開啟線程數 |
weight_decay | float | 5e-4 | 權重衰減系數 |
gpu_id | list | [ 0 ] | 指定訓練的GPU卡号,可指定多個 |
log_freq | in | 20 | 顯示LOG資訊的頻率 |
finetune | str | model.pth | finetune的模型 |
progress | bool | True | 是否顯示進度條 |
distributed | bool | False | 是否使用分布式訓練 |
(4)開始訓練
整套訓練代碼非常簡單操作,使用者隻需要将相同類别的資料放在同一個目錄下,并填寫好對應的資料路徑,即可開始訓練了。
python train.py -c configs/config.yaml
(5)可視化訓練過程
訓練過程可視化工具是使用Tensorboard,使用方法:
# 基本方法
tensorboard --logdir=path/to/log/
# 例如
tensorboard --logdir=work_space/mobilenet_v2_1.0_CrossEntropyLoss/log
可視化效果
(6)一些優化建議
訓練完成後,在垃圾資料集dataset2訓練集的Accuracy在94%左右,測試集的Accuracy在92%左右;而在垃圾資料集dataset1的Accuracy在83%左右,如果想進一步提高準确率,可以嘗試:
- 最重要的: 清洗資料集,垃圾資料集dataset1和垃圾資料集dataset2,大部分資料都是網上爬取的,品質并不高,存在很多錯誤的圖檔,盡管鄙人已經清洗一部分了,但還是建議你,訓練前,再次清洗資料集,不然會影響模型的識别的準确率。
- 使用不同backbone模型,比如resnet50或者更深模型
- 增加資料增強: 已經支援: 随機裁剪,随機翻轉,随機旋轉,顔色變換等資料增強方式,可以嘗試諸如mixup,CutMix等更複雜的資料增強方式
- 樣本均衡: 目前訓練代碼已經支援樣本重采樣,設定resample=True即可
- 調超參: 比如學習率調整政策,優化器(SGD,Adam等)
- 損失函數: 目前訓練代碼已經支援:交叉熵,LabelSmoothing,可以嘗試FocalLoss等損失函數
4. 垃圾分類識别模型測試效果
demo.py檔案用于推理和測試模型的效果,填寫好配置檔案,模型檔案以及測試圖檔即可運作測試了
def get_parser():
# 配置檔案
config_file = "data/pretrained/resnet18_1.0_CrossEntropyLoss_20220822153756/config.yaml"
# 模型檔案
model_file = "data/pretrained/resnet18_1.0_CrossEntropyLoss_20220822153756/model/best_model_043_92.5587.pth"
# 待測試圖檔目錄
image_dir = "/home/dm/nasdata/dataset/csdn/rubbish/dataset2/test"
parser = argparse.ArgumentParser(description="Inference Argument")
parser.add_argument("-c", "--config_file", help="configs file", default=config_file, type=str)
parser.add_argument("-m", "--model_file", help="model_file", default=model_file, type=str)
parser.add_argument("--device", help="cuda device id", default="cuda:0", type=str)
parser.add_argument("--image_dir", help="image file or directory", default=image_dir, type=str)
return parser
python demo.py \
-c "data/pretrained/resnet18_1.0_CrossEntropyLoss_20220822153756/config.yaml" \
-m "data/pretrained/resnet18_1.0_CrossEntropyLoss_20220822153756/model/best_model_043_92.5587.pth" \
--image_dir "data/test_images/rubbish"
運作測試結果:
pred_index:['Organic'],pred_score:[0.9952668] | pred_index:['Organic'],pred_score:[0.9911327] |
pred_index:['Recycle'],pred_score:[0.73851496] |
5.項目源碼下載下傳
- 垃圾資料集dataset1,其中Train集16200張圖檔,Test集800張圖檔
- 垃圾資料集dataset2,其中Train集22,566張圖檔,Test集2,513張圖檔
- 整套垃圾分類訓練代碼和測試代碼(Pytorch版本)